diff options
Diffstat (limited to 'pkg')
232 files changed, 10054 insertions, 3261 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 4a26e28de..a0654df2f 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -55,6 +55,8 @@ go_library( "sched.go", "seccomp.go", "sem.go", + "sem_amd64.go", + "sem_arm64.go", "shm.go", "signal.go", "signalfd.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 7df02dd6d..006b5a525 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -121,6 +121,9 @@ const ( // Constants from uapi/linux/fsverity.h. const ( + FS_VERITY_HASH_ALG_SHA256 = 1 + FS_VERITY_HASH_ALG_SHA512 = 2 + FS_IOC_ENABLE_VERITY = 1082156677 FS_IOC_MEASURE_VERITY = 3221513862 ) diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index 487a626cc..1b2f76c0b 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -34,18 +34,6 @@ const ( const SEM_UNDO = 0x1000 -// SemidDS is equivalent to struct semid64_ds. -// -// +marshal -type SemidDS struct { - SemPerm IPCPerm - SemOTime TimeT - SemCTime TimeT - SemNSems uint64 - unused3 uint64 - unused4 uint64 -} - // Sembuf is equivalent to struct sembuf. // // +marshal slice:SembufSlice diff --git a/pkg/abi/linux/sem_amd64.go b/pkg/abi/linux/sem_amd64.go new file mode 100644 index 000000000..ab980cb4f --- /dev/null +++ b/pkg/abi/linux/sem_amd64.go @@ -0,0 +1,33 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: arch/x86/include/uapi/asm/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + unused1 uint64 + SemCTime TimeT + unused2 uint64 + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/abi/linux/sem_arm64.go b/pkg/abi/linux/sem_arm64.go new file mode 100644 index 000000000..521468fb1 --- /dev/null +++ b/pkg/abi/linux/sem_arm64.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: include/uapi/asm-generic/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + SemCTime TimeT + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/context/context.go b/pkg/context/context.go index 2613bc752..f3031fc60 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -166,3 +166,27 @@ var bgContext = &logContext{Logger: log.Log()} func Background() Context { return bgContext } + +// WithValue returns a copy of parent in which the value associated with key is +// val. +func WithValue(parent Context, key, val interface{}) Context { + return &withValue{ + Context: parent, + key: key, + val: val, + } +} + +type withValue struct { + Context + key interface{} + val interface{} +} + +// Value implements Context.Value. +func (ctx *withValue) Value(key interface{}) interface{} { + if key == ctx.key { + return ctx.val + } + return ctx.Context.Value(key) +} diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index bee28b68d..a493e3407 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -6,6 +6,7 @@ go_library( name = "eventchannel", srcs = [ "event.go", + "event_any.go", "rate.go", ], visibility = ["//:sandbox"], @@ -14,8 +15,9 @@ go_library( "//pkg/log", "//pkg/sync", "//pkg/unet", - "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_golang_protobuf//ptypes:go_default_library_gen", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//types/known/anypb:go_default_library", "@org_golang_x_time//rate:go_default_library", ], ) @@ -32,6 +34,6 @@ go_test( library = ":eventchannel", deps = [ "//pkg/sync", - "@com_github_golang_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", ], ) diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go index 9a29c58bd..7172ce75d 100644 --- a/pkg/eventchannel/event.go +++ b/pkg/eventchannel/event.go @@ -24,8 +24,8 @@ import ( "fmt" "syscall" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" @@ -118,22 +118,6 @@ func (me *multiEmitter) Close() error { return err } -func marshal(msg proto.Message) ([]byte, error) { - anypb, err := ptypes.MarshalAny(msg) - if err != nil { - return nil, err - } - - // Wire format is uvarint message length followed by binary proto. - bufMsg, err := proto.Marshal(anypb) - if err != nil { - return nil, err - } - p := make([]byte, binary.MaxVarintLen64) - n := binary.PutUvarint(p, uint64(len(bufMsg))) - return append(p[:n], bufMsg...), nil -} - // socketEmitter emits proto messages on a socket. type socketEmitter struct { socket *unet.Socket @@ -155,10 +139,19 @@ func SocketEmitter(fd int) (Emitter, error) { // Emit implements Emitter.Emit. func (s *socketEmitter) Emit(msg proto.Message) (bool, error) { - p, err := marshal(msg) + any, err := newAny(msg) if err != nil { return false, err } + bufMsg, err := proto.Marshal(any) + if err != nil { + return false, err + } + + // Wire format is uvarint message length followed by binary proto. + p := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(p, uint64(len(bufMsg))) + p = append(p[:n], bufMsg...) for done := 0; done < len(p); { n, err := s.socket.Write(p[done:]) if err != nil { @@ -166,6 +159,7 @@ func (s *socketEmitter) Emit(msg proto.Message) (bool, error) { } done += n } + return false, nil } @@ -189,9 +183,13 @@ func DebugEmitterFrom(inner Emitter) Emitter { } func (d *debugEmitter) Emit(msg proto.Message) (bool, error) { + text, err := prototext.Marshal(msg) + if err != nil { + return false, err + } ev := &pb.DebugEvent{ - Name: proto.MessageName(msg), - Text: proto.MarshalTextString(msg), + Name: string(msg.ProtoReflect().Descriptor().FullName()), + Text: string(text), } return d.inner.Emit(ev) } diff --git a/pkg/eventchannel/event.proto b/pkg/eventchannel/event.proto index 34468f072..4b24ac47c 100644 --- a/pkg/eventchannel/event.proto +++ b/pkg/eventchannel/event.proto @@ -16,7 +16,7 @@ syntax = "proto3"; package gvisor; -// A debug event encapsulates any other event protobuf in text format. This is +// DebugEvent encapsulates any other event protobuf in text format. This is // useful because clients reading events emitted this way do not need to link // the event protobufs to display them in a human-readable format. message DebugEvent { diff --git a/pkg/eventchannel/event_any.go b/pkg/eventchannel/event_any.go new file mode 100644 index 000000000..a5549f6cd --- /dev/null +++ b/pkg/eventchannel/event_any.go @@ -0,0 +1,25 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package eventchannel + +import ( + "google.golang.org/protobuf/types/known/anypb" + + "google.golang.org/protobuf/proto" +) + +func newAny(m proto.Message) (*anypb.Any, error) { + return anypb.New(m) +} diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 43750360b..0dd408f76 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -19,7 +19,7 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/proto" "gvisor.dev/gvisor/pkg/sync" ) diff --git a/pkg/eventchannel/rate.go b/pkg/eventchannel/rate.go index 179226c92..74960e16a 100644 --- a/pkg/eventchannel/rate.go +++ b/pkg/eventchannel/rate.go @@ -15,8 +15,8 @@ package eventchannel import ( - "github.com/golang/protobuf/proto" "golang.org/x/time/rate" + "google.golang.org/protobuf/proto" ) // rateLimitedEmitter wraps an emitter and limits events to the given limits. diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD index a8fcb2e19..501a9ef21 100644 --- a/pkg/merkletree/BUILD +++ b/pkg/merkletree/BUILD @@ -6,12 +6,18 @@ go_library( name = "merkletree", srcs = ["merkletree.go"], visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) go_test( name = "merkletree_test", srcs = ["merkletree_test.go"], library = ":merkletree", - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index d8227b8bd..18457d287 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -18,21 +18,32 @@ package merkletree import ( "bytes" "crypto/sha256" + "crypto/sha512" "fmt" "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) const ( // sha256DigestSize specifies the digest size of a SHA256 hash. sha256DigestSize = 32 + // sha512DigestSize specifies the digest size of a SHA512 hash. + sha512DigestSize = 64 ) // DigestSize returns the size (in bytes) of a digest. -// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). -func DigestSize() int { - return sha256DigestSize +// TODO(b/156980949): Allow config SHA384. +func DigestSize(hashAlgorithm int) int { + switch hashAlgorithm { + case linux.FS_VERITY_HASH_ALG_SHA256: + return sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + return sha512DigestSize + default: + return -1 + } } // Layout defines the scale of a Merkle tree. @@ -51,11 +62,19 @@ type Layout struct { // InitLayout initializes and returns a new Layout object describing the structure // of a tree. dataSize specifies the size of input data in bytes. -func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { +func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) { layout := Layout{ blockSize: usermem.PageSize, - // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). - digestSize: sha256DigestSize, + } + + // TODO(b/156980949): Allow config SHA384. + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + layout.digestSize = sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + layout.digestSize = sha512DigestSize + default: + return Layout{}, fmt.Errorf("unexpected hash algorithms") } // treeStart is the offset (in bytes) of the first level of the tree in @@ -88,7 +107,7 @@ func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { } layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) - return layout + return layout, nil } // hashesPerBlock() returns the number of digests in each block. For example, @@ -139,12 +158,33 @@ func (d *VerityDescriptor) String() string { } // verify generates a hash from d, and compares it with expected. -func (d *VerityDescriptor) verify(expected []byte) error { - h := sha256.Sum256([]byte(d.String())) +func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error { + h, err := hashData([]byte(d.String()), hashAlgorithms) + if err != nil { + return err + } if !bytes.Equal(h[:], expected) { return fmt.Errorf("unexpected root hash") } return nil + +} + +// hashData hashes data and returns the result hash based on the hash +// algorithms. +func hashData(data []byte, hashAlgorithms int) ([]byte, error) { + var digest []byte + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + digestArray := sha256.Sum256(data) + digest = digestArray[:] + case linux.FS_VERITY_HASH_ALG_SHA512: + digestArray := sha512.Sum512(data) + digest = digestArray[:] + default: + return nil, fmt.Errorf("unexpected hash algorithms") + } + return digest, nil } // GenerateParams contains the parameters used to generate a Merkle tree. @@ -161,6 +201,8 @@ type GenerateParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // TreeReader is a reader for the Merkle tree. TreeReader io.ReaderAt // TreeWriter is a writer for the Merkle tree. @@ -176,7 +218,10 @@ type GenerateParams struct { // Generate returns a hash of a VerityDescriptor, which contains the file // metadata and the hash from file content. func Generate(params *GenerateParams) ([]byte, error) { - layout := InitLayout(params.Size, params.DataAndTreeInSameFile) + layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return nil, err + } numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize @@ -218,10 +263,13 @@ func Generate(params *GenerateParams) ([]byte, error) { return nil, err } // Hash the bytes in buf. - digest := sha256.Sum256(buf) + digest, err := hashData(buf, params.HashAlgorithms) + if err != nil { + return nil, err + } if level == layout.rootLevel() { - root = digest[:] + root = digest } // Write the generated hash to the end of the tree file. @@ -246,8 +294,7 @@ func Generate(params *GenerateParams) ([]byte, error) { GID: params.GID, RootHash: root, } - ret := sha256.Sum256([]byte(descriptor.String())) - return ret[:], nil + return hashData([]byte(descriptor.String()), params.HashAlgorithms) } // VerifyParams contains the params used to verify a portion of a file against @@ -269,6 +316,8 @@ type VerifyParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // ReadOffset is the offset of the data range to be verified. ReadOffset int64 // ReadSize is the size of the data range to be verified. @@ -298,7 +347,7 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error { GID: params.GID, RootHash: root, } - return descriptor.verify(params.Expected) + return descriptor.verify(params.Expected, params.HashAlgorithms) } // Verify verifies the content read from data with offset. The content is @@ -313,7 +362,10 @@ func Verify(params *VerifyParams) (int64, error) { if params.ReadSize < 0 { return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize) } - layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile) + layout, err := InitLayout(int64(params.Size), params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return 0, err + } if params.ReadSize == 0 { return 0, verifyMetadata(params, &layout) } @@ -354,7 +406,7 @@ func Verify(params *VerifyParams) (int64, error) { UID: params.UID, GID: params.GID, } - if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil { + if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil { return 0, err } @@ -395,7 +447,7 @@ func Verify(params *VerifyParams) (int64, error) { // fails if the calculated hash from block is different from any level of // hashes stored in tree. And the final root hash is compared with // expected. -func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error { +func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error { if len(dataBlock) != int(layout.blockSize) { return fmt.Errorf("incorrect block size") } @@ -406,8 +458,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, for level := 0; level < layout.numLevels(); level++ { // Calculate hash. if level == 0 { - digestArray := sha256.Sum256(dataBlock) - digest = digestArray[:] + h, err := hashData(dataBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } else { // Read a block in previous level that contains the // hash we just generated, and generate a next level @@ -415,8 +470,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil { return err } - digestArray := sha256.Sum256(treeBlock) - digest = digestArray[:] + h, err := hashData(treeBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } // Read the digest for the current block and store in @@ -434,5 +492,5 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, // Verification for the tree succeeded. Now hash the descriptor with // the root hash and compare it with expected. descriptor.RootHash = digest - return descriptor.verify(expected) + return descriptor.verify(expected, hashAlgorithms) } diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index e1350ebda..0782ca3e7 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -22,54 +22,114 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) func TestLayout(t *testing.T) { testCases := []struct { dataSize int64 + hashAlgorithms int dataAndTreeInSameFile bool + expectedDigestSize int64 expectedLevelOffset []int64 }{ { dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0}, }, { dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0}, + }, + { + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, + expectedLevelOffset: []int64{usermem.PageSize}, + }, + { + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, expectedLevelOffset: []int64{usermem.PageSize}, }, { dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, }, { dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize}, + }, + { + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, }, { + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize}, + }, + { dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, }, { dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize}, + }, + { + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, }, + { + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize}, + }, } for _, tc := range testCases { t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { - l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) + l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile) + if err != nil { + t.Fatalf("Failed to InitLayout: %v", err) + } if l.blockSize != int64(usermem.PageSize) { t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } - if l.digestSize != sha256DigestSize { + if l.digestSize != tc.expectedDigestSize { t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) } if l.numLevels() != len(tc.expectedLevelOffset) { @@ -118,24 +178,49 @@ func TestGenerate(t *testing.T) { // The input data has size dataSize. It starts with the data in startWith, // and all other bytes are zeroes. testCases := []struct { - data []byte - expectedHash []byte + data []byte + hashAlgorithms int + expectedHash []byte }{ { - data: bytes.Repeat([]byte{0}, usermem.PageSize), - expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, + }, + { + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{14, 27, 126, 158, 9, 94, 163, 51, 243, 162, 82, 167, 183, 127, 93, 121, 221, 23, 184, 59, 104, 166, 111, 49, 161, 195, 229, 111, 121, 201, 233, 68, 10, 154, 78, 142, 154, 236, 170, 156, 110, 167, 15, 144, 155, 97, 241, 235, 202, 233, 246, 217, 138, 88, 152, 179, 238, 46, 247, 185, 125, 20, 101, 201}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{55, 204, 240, 1, 224, 252, 58, 131, 251, 174, 45, 140, 107, 57, 118, 11, 18, 236, 203, 204, 19, 59, 27, 196, 3, 78, 21, 7, 22, 98, 197, 128, 17, 128, 90, 122, 54, 83, 253, 108, 156, 67, 59, 229, 236, 241, 69, 88, 99, 44, 127, 109, 204, 183, 150, 232, 187, 57, 228, 137, 209, 235, 241, 172}, }, { - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, }, { - data: []byte{'a'}, - expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{207, 233, 114, 94, 113, 212, 243, 160, 59, 232, 226, 77, 28, 81, 176, 61, 211, 213, 222, 190, 148, 196, 90, 166, 237, 56, 113, 148, 230, 154, 23, 105, 14, 97, 144, 211, 12, 122, 226, 207, 167, 203, 136, 193, 38, 249, 227, 187, 92, 238, 101, 97, 170, 255, 246, 209, 246, 98, 241, 150, 175, 253, 173, 206}, }, { - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, + }, + { + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{110, 103, 29, 250, 27, 211, 235, 119, 112, 65, 49, 156, 6, 92, 66, 105, 133, 1, 187, 172, 169, 13, 186, 34, 105, 72, 252, 131, 12, 159, 91, 188, 79, 184, 240, 227, 40, 164, 72, 193, 65, 31, 227, 153, 191, 6, 117, 42, 82, 122, 33, 255, 92, 215, 215, 249, 2, 131, 170, 134, 39, 192, 222, 33}, }, } @@ -149,6 +234,7 @@ func TestGenerate(t *testing.T) { Mode: defaultMode, UID: defaultUID, GID: defaultGID, + HashAlgorithms: tc.hashAlgorithms, TreeReader: &tree, TreeWriter: &tree, DataAndTreeInSameFile: dataAndTreeInSameFile, @@ -348,77 +434,81 @@ func TestVerify(t *testing.T) { // Generate random bytes in data. rand.Read(data) - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, + for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Flip a bit in data and checks Verify results. - var buf bytes.Buffer - data[tc.modifyByte] ^= 1 - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: tc.dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: tc.verifyStart, - ReadSize: tc.verifySize, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if tc.modifyName { - verifyParams.Name = defaultName + "abc" - } - if tc.modifyMode { - verifyParams.Mode = defaultMode + 1 - } - if tc.modifyUID { - verifyParams.UID = defaultUID + 1 - } - if tc.modifyGID { - verifyParams.GID = defaultGID + 1 - } - if tc.shouldSucceed { - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed when expected to succeed: %v", err) + // Flip a bit in data and checks Verify results. + var buf bytes.Buffer + data[tc.modifyByte] ^= 1 + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: tc.dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + ReadOffset: tc.verifyStart, + ReadSize: tc.verifySize, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, } - if n != tc.verifySize { - t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + if tc.modifyName { + verifyParams.Name = defaultName + "abc" } - if int64(buf.Len()) != tc.verifySize { - t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + if tc.modifyMode { + verifyParams.Mode = defaultMode + 1 } - if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") + if tc.modifyUID { + verifyParams.UID = defaultUID + 1 } - } else { - if _, err := Verify(&verifyParams); err == nil { - t.Errorf("Verification succeeded when expected to fail") + if tc.modifyGID { + verifyParams.GID = defaultGID + 1 + } + if tc.shouldSucceed { + n, err := Verify(&verifyParams) + if err != nil && err != io.EOF { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + if n != tc.verifySize { + t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + } + if int64(buf.Len()) != tc.verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + } + if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } + } else { + if _, err := Verify(&verifyParams); err == nil { + t.Errorf("Verification succeeded when expected to fail") + } } } } @@ -435,87 +525,91 @@ func TestVerifyRandom(t *testing.T) { // Generate random bytes in data. rand.Read(data) - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Pick a random portion of data. - start := rand.Int63n(dataSize - 1) - size := rand.Int63n(dataSize) + 1 + // Pick a random portion of data. + start := rand.Int63n(dataSize - 1) + size := rand.Int63n(dataSize) + 1 - var buf bytes.Buffer - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: start, - ReadSize: size, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + var buf bytes.Buffer + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + ReadOffset: start, + ReadSize: size, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } - // Checks that the random portion of data from the original data is - // verified successfully. - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed for correct data: %v", err) - } - if size > dataSize-start { - size = dataSize - start - } - if n != size { - t.Errorf("Got Verify output size %d, want %d", n, size) - } - if int64(buf.Len()) != size { - t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) - } - if !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } + // Checks that the random portion of data from the original data is + // verified successfully. + n, err := Verify(&verifyParams) + if err != nil && err != io.EOF { + t.Errorf("Verification failed for correct data: %v", err) + } + if size > dataSize-start { + size = dataSize - start + } + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } - // Verify that modified metadata should fail verification. - buf.Reset() - verifyParams.Name = defaultName + "abc" - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verify succeeded for modified metadata, expect failure") - } + // Verify that modified metadata should fail verification. + buf.Reset() + verifyParams.Name = defaultName + "abc" + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verify succeeded for modified metadata, expect failure") + } - // Flip a random bit in randPortion, and check that verification fails. - buf.Reset() - randBytePos := rand.Int63n(size) - data[start+randBytePos] ^= 1 - verifyParams.File = bytes.NewReader(data) - verifyParams.Name = defaultName + // Flip a random bit in randPortion, and check that verification fails. + buf.Reset() + randBytePos := rand.Int63n(size) + data[start+randBytePos] ^= 1 + verifyParams.File = bytes.NewReader(data) + verifyParams.Name = defaultName - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verification succeeded for modified data, expect failure") + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verification succeeded for modified data, expect failure") + } } } } diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 58305009d..0a6a5d215 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -27,6 +27,6 @@ go_test( deps = [ ":metric_go_proto", "//pkg/eventchannel", - "@com_github_golang_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", ], ) diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go index c425ea532..aefd0ea5c 100644 --- a/pkg/metric/metric_test.go +++ b/pkg/metric/metric_test.go @@ -17,7 +17,7 @@ package metric import ( "testing" - "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/proto" "gvisor.dev/gvisor/pkg/eventchannel" pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" ) diff --git a/pkg/refs_vfs2/BUILD b/pkg/refsvfs2/BUILD index 577b827a5..245e33d2d 100644 --- a/pkg/refs_vfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -19,8 +19,16 @@ go_template( ) go_library( - name = "refs_vfs2", - srcs = ["refs.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/context"], + name = "refsvfs2", + srcs = [ + "refs.go", + "refs_map.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/context", + "//pkg/log", + "//pkg/refs", + "//pkg/sync", + ], ) diff --git a/pkg/refs_vfs2/refs.go b/pkg/refsvfs2/refs.go index 99a074e96..ef8beb659 100644 --- a/pkg/refs_vfs2/refs.go +++ b/pkg/refsvfs2/refs.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package refs_vfs2 defines an interface for a reference-counted object. -package refs_vfs2 +// Package refsvfs2 defines an interface for a reference-counted object. +package refsvfs2 import ( "gvisor.dev/gvisor/pkg/context" diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go new file mode 100644 index 000000000..be75b0cc2 --- /dev/null +++ b/pkg/refsvfs2/refs_map.go @@ -0,0 +1,97 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package refsvfs2 + +import ( + "fmt" + "strings" + + "gvisor.dev/gvisor/pkg/log" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" +) + +// TODO(gvisor.dev/issue/1193): re-enable once kernfs refs are fixed. +var ignored []string = []string{"kernfs.", "proc.", "sys.", "devpts.", "fuse."} + +var ( + // liveObjects is a global map of reference-counted objects. Objects are + // inserted when leak check is enabled, and they are removed when they are + // destroyed. It is protected by liveObjectsMu. + liveObjects map[CheckedObject]struct{} + liveObjectsMu sync.Mutex +) + +// CheckedObject represents a reference-counted object with an informative +// leak detection message. +type CheckedObject interface { + // LeakMessage supplies a warning to be printed upon leak detection. + LeakMessage() string +} + +func init() { + liveObjects = make(map[CheckedObject]struct{}) +} + +// LeakCheckEnabled returns whether leak checking is enabled. The following +// functions should only be called if it returns true. +func LeakCheckEnabled() bool { + return refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking +} + +// Register adds obj to the live object map. +func Register(obj CheckedObject, typ string) { + for _, str := range ignored { + if strings.Contains(typ, str) { + return + } + } + liveObjectsMu.Lock() + if _, ok := liveObjects[obj]; ok { + panic(fmt.Sprintf("Unexpected entry in leak checking map: reference %p already added", obj)) + } + liveObjects[obj] = struct{}{} + liveObjectsMu.Unlock() +} + +// Unregister removes obj from the live object map. +func Unregister(obj CheckedObject, typ string) { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + if _, ok := liveObjects[obj]; !ok { + for _, str := range ignored { + if strings.Contains(typ, str) { + return + } + } + panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj)) + } + delete(liveObjects, obj) +} + +// DoLeakCheck iterates through the live object map and logs a message for each +// object. It is called once no reference-counted objects should be reachable +// anymore, at which point anything left in the map is considered a leak. +func DoLeakCheck() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + leaked := len(liveObjects) + if leaked > 0 { + log.Warningf("Leak checking detected %d leaked objects:", leaked) + for obj := range liveObjects { + log.Warningf(obj.LeakMessage()) + } + } +} diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index d9b552896..ec295ef5b 100644 --- a/pkg/refs_vfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -21,11 +21,9 @@ package refs_template import ( "fmt" - "runtime" "sync/atomic" - "gvisor.dev/gvisor/pkg/log" - refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" ) // T is the type of the reference counted object. It is only used to customize @@ -42,11 +40,6 @@ var ownerType *T // Note that the number of references is actually refCount + 1 so that a default // zero-value Refs object contains one reference. // -// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in -// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount. -// This will allow us to add stack trace information to the leak messages -// without growing the size of Refs. -// // +stateify savable type Refs struct { // refCount is composed of two fields: @@ -59,24 +52,16 @@ type Refs struct { refCount int64 } -func (r *Refs) finalize() { - var note string - switch refs_vfs1.GetLeakMode() { - case refs_vfs1.NoLeakChecking: - return - case refs_vfs1.UninitializedLeakChecking: - note = "(Leak checker uninitialized): " - } - if n := r.ReadRefs(); n != 0 { - log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) +// EnableLeakCheck enables reference leak checking on r. +func (r *Refs) EnableLeakCheck() { + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(r, fmt.Sprintf("%T", ownerType)) } } -// EnableLeakCheck checks for reference leaks when Refs gets garbage collected. -func (r *Refs) EnableLeakCheck() { - if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { - runtime.SetFinalizer(r, (*Refs).finalize) - } +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *Refs) LeakMessage() string { + return fmt.Sprintf("%T %p: reference count of %d instead of 0", ownerType, r, r.ReadRefs()) } // ReadRefs returns the current number of references. The returned count is @@ -91,7 +76,7 @@ func (r *Refs) ReadRefs() int64 { //go:nosplit func (r *Refs) IncRef() { if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { - panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType)) + panic(fmt.Sprintf("Incrementing non-positive count %p on %T", r, ownerType)) } } @@ -134,9 +119,18 @@ func (r *Refs) DecRef(destroy func()) { panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType)) case v == -1: + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(r, fmt.Sprintf("%T", ownerType)) + } // Call the destructor. if destroy != nil { destroy() } } } + +func (r *Refs) afterLoad() { + if refsvfs2.LeakCheckEnabled() && r.ReadRefs() > 0 { + r.EnableLeakCheck() + } +} diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index 41feeffe3..d800f2c85 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -69,5 +69,5 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { s.Kernel.Kill(kernel.ExitStatus{}) }, } - return saveOpts.Save(s.Kernel, s.Watchdog) + return saveOpts.Save(s.Kernel.SupervisorContext(), s.Kernel, s.Watchdog) } diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index 655ea549b..ff5d49fbd 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -39,6 +39,8 @@ const ( ) // tunDevice implements vfs.Device for /dev/net/tun. +// +// +stateify savable type tunDevice struct{} // Open implements vfs.Device.Open. @@ -53,6 +55,8 @@ func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opt } // tunFD implements vfs.FileDescriptionImpl for /dev/net/tun. +// +// +stateify savable type tunFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 1390a9a7f..4468f5dd2 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -70,6 +70,13 @@ func (f *HostFileMapper) Init() { f.mappings = make(map[uint64]mapping) } +// IsInited returns true if f.Init() has been called. This is used when +// restoring a checkpoint that contains a HostFileMapper that may or may not +// have been initialized. +func (f *HostFileMapper) IsInited() bool { + return f.refs != nil +} + // NewHostFileMapper returns an initialized HostFileMapper allocated on the // heap with no references or cached mappings. func NewHostFileMapper() *HostFileMapper { diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 22d658acf..450044c9c 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -92,6 +92,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo "gid_map": newGIDMap(t, msrc), "io": newIO(t, msrc, isThreadGroup), "maps": newMaps(t, msrc), + "mem": newMem(t, msrc), "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), "net": newNetDir(t, msrc), @@ -399,6 +400,88 @@ func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { return newProcInode(t, d, msrc, fs.SpecialDirectory, t) } +// memData implements fs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memData struct { + fsutil.SimpleFileInode + + t *kernel.Task +} + +// memDataFile implements fs.FileOperations for /proc/[pid]/mem. +// +// +stateify savable +type memDataFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + t *kernel.Task +} + +func newMem(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + inode := &memData{ + SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC), + t: t, + } + return newProcInode(t, inode, msrc, fs.SpecialFile, t) +} + +// Truncate implements fs.InodeOperations.Truncate. +func (m *memData) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (m *memData) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, m.t, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(m.t); err != nil { + return nil, err + } + // Enable random access reads + flags.Pread = true + return fs.NewFile(ctx, dirent, flags, &memDataFile{t: m.t}), nil +} + +// Read implements fs.FileOperations.Read. +func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + mm, err := getTaskMM(m.t) + if err != nil { + return 0, nil + } + defer mm.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + // mapsData implements seqfile.SeqSource for /proc/[pid]/maps. // // +stateify savable diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 84baaac66..6af3c3781 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "root_inode_refs.go", package = "devpts", prefix = "rootInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "rootInode", }, @@ -33,6 +33,7 @@ go_library( "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index d5c5aaa8c..9185877f6 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -60,7 +60,7 @@ func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Vir } fstype.initOnce.Do(func() { - fs, root, err := fstype.newFilesystem(vfsObj, creds) + fs, root, err := fstype.newFilesystem(ctx, vfsObj, creds) if err != nil { fstype.initErr = err return @@ -93,7 +93,7 @@ type filesystem struct { // newFilesystem creates a new devpts filesystem with root directory and ptmx // master inode. It returns the filesystem and root Dentry. -func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { +func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err @@ -108,7 +108,7 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds root := &rootInode{ replicas: make(map[uint32]*replicaInode), } - root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) + root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) root.EnableLeakCheck() @@ -120,7 +120,7 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds master := &masterInode{ root: root, } - master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) + master.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) // Add the master as a child of the root. links := root.OrderedChildren.Populate(map[string]kernfs.Inode{ @@ -170,7 +170,7 @@ type rootInode struct { var _ kernfs.Inode = (*rootInode)(nil) // allocateTerminal creates a new Terminal and installs a pts node for it. -func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) { +func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credentials) (*Terminal, error) { i.mu.Lock() defer i.mu.Unlock() if i.nextIdx == math.MaxUint32 { @@ -192,7 +192,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). - replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) + replica.InodeAttrs.Init(ctx, creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) i.replicas[idx] = replica return t, nil @@ -248,9 +248,10 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro } // IterDirents implements kernfs.Inode.IterDirents. -func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *rootInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { i.mu.Lock() defer i.mu.Unlock() + i.InodeAttrs.TouchAtime(ctx, mnt) ids := make([]int, 0, len(i.replicas)) for id := range i.replicas { ids = append(ids, int(id)) diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index e6b0e81cf..ae95fdd08 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -100,10 +100,10 @@ type lineDiscipline struct { column int // masterWaiter is used to wait on the master end of the TTY. - masterWaiter waiter.Queue `state:"zerovalue"` + masterWaiter waiter.Queue // replicaWaiter is used to wait on the replica end of the TTY. - replicaWaiter waiter.Queue `state:"zerovalue"` + replicaWaiter waiter.Queue } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index fda30fb93..e91fa26a4 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -50,7 +50,7 @@ var _ kernfs.Inode = (*masterInode)(nil) // Open implements kernfs.Inode.Open. func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - t, err := mi.root.allocateTerminal(rp.Credentials()) + t, err := mi.root.allocateTerminal(ctx, rp.Credentials()) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD index 01bbee5ad..e49a04c1b 100644 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ b/pkg/sentry/fsimpl/devtmpfs/BUILD @@ -4,7 +4,10 @@ licenses(["notice"]) go_library( name = "devtmpfs", - srcs = ["devtmpfs.go"], + srcs = [ + "devtmpfs.go", + "save_restore.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/devtmpfs/save_restore.go b/pkg/sentry/fsimpl/devtmpfs/save_restore.go new file mode 100644 index 000000000..28832d850 --- /dev/null +++ b/pkg/sentry/fsimpl/devtmpfs/save_restore.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package devtmpfs + +// afterLoad is invoked by stateify. +func (fst *FilesystemType) afterLoad() { + if fst.fs != nil { + // Ensure that we don't create another filesystem. + fst.initOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 1c27ad700..5b29f2358 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -43,7 +43,7 @@ type EventFileDescription struct { // queue is used to notify interested parties when the event object // becomes readable or writable. - queue waiter.Queue `state:"zerovalue"` + queue waiter.Queue // mu protects the fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 045d7ab08..2158b1bbc 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "inode_refs.go", package = "fuse", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -49,6 +49,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/kernfs", diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index e39df21c6..e7ef5998e 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -205,7 +205,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // root is the fusefs root directory. - root := fs.newRootInode(creds, fsopts.rootMode) + root := fs.newRootInode(ctx, creds, fsopts.rootMode) return fs.VFSFilesystem(), root.VFSDentry(), nil } @@ -284,9 +284,9 @@ type inode struct { link string } -func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { +func (fs *filesystem) newRootInode(ctx context.Context, creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { i := &inode{fs: fs, nodeID: 1} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) + i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() @@ -295,10 +295,10 @@ func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) return &d } -func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { +func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { i := &inode{fs: fs, nodeID: nodeID} creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)} - i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) + i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) atomic.StoreUint64(&i.size, attr.Size) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() @@ -424,7 +424,7 @@ func (i *inode) Keep() bool { } // IterDirents implements kernfs.Inode.IterDirents. -func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (*inode) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { return offset, nil } @@ -544,7 +544,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { return nil, syserror.EIO } - child := i.fs.newInode(out.NodeID, out.Attr) + child := i.fs.newInode(ctx, out.NodeID, out.Attr) return child, nil } @@ -696,7 +696,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return linux.FUSEAttr{}, err @@ -812,7 +812,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 625d1547f..2d396e84c 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -132,7 +132,7 @@ func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off u // May need to update the signature. i := fd.inode() - // TODO(gvisor.dev/issue/1193): Invalidate or update atime. + i.InodeAttrs.TouchAtime(ctx, fd.vfsfd.Mount()) // Reached EOF. if sizeRead < size { @@ -179,6 +179,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, Flags: fd.statusFlags(), } + inode := fd.inode() var written uint32 // This loop is intended for fragmented write where the bytes to write is @@ -203,7 +204,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, in.Offset = off + uint64(written) in.Size = toWrite - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in) + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) if err != nil { return 0, err } @@ -237,6 +238,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, break } } + inode.InodeAttrs.TouchCMtime(ctx) return written, nil } diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index ad0afc41b..4c3e9acf8 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "host_named_pipe.go", "p9file.go", "regular_file.go", + "save_restore.go", "socket.go", "special_file.go", "symlink.go", @@ -53,6 +54,7 @@ go_library( "//pkg/log", "//pkg/p9", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", @@ -70,6 +72,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 18c884b59..e993c8e36 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -16,16 +16,17 @@ package gofer import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -92,7 +93,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { child := &dentry{ refs: 1, // held by d fs: d.fs, - ino: d.fs.nextSyntheticIno(), + ino: d.fs.nextIno(), mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), @@ -100,6 +101,9 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { hostFD: -1, nlink: uint32(2), } + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(child, "gofer.dentry") + } switch opts.mode.FileType() { case linux.S_IFDIR: // Nothing else needs to be done. @@ -235,7 +239,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { } dirent := vfs.Dirent{ Name: p9d.Name, - Ino: uint64(inoFromPath(p9d.QID.Path)), + Ino: d.fs.inoFromQIDPath(p9d.QID.Path), NextOff: int64(len(dirents) + 1), } // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 94d96261b..baecb88c4 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -35,7 +35,7 @@ import ( // Sync implements vfs.FilesystemImpl.Sync. func (fs *filesystem) Sync(ctx context.Context) error { - // Snapshot current syncable dentries and special files. + // Snapshot current syncable dentries and special file FDs. fs.syncMu.Lock() ds := make([]*dentry, 0, len(fs.syncableDentries)) for d := range fs.syncableDentries { @@ -53,22 +53,28 @@ func (fs *filesystem) Sync(ctx context.Context) error { // regardless. var retErr error - // Sync regular files. + // Sync syncable dentries. for _, d := range ds { - err := d.syncCachedFile(ctx) + err := d.syncCachedFile(ctx, true /* forFilesystemSync */) d.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) + if retErr == nil { + retErr = err + } } } // Sync special files, which may be writable but do not use dentry shared // handles (so they won't be synced by the above). for _, sffd := range sffds { - err := sffd.Sync(ctx) + err := sffd.sync(ctx, true /* forFilesystemSync */) sffd.vfsfd.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) + if retErr == nil { + retErr = err + } } } @@ -229,7 +235,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir return nil, err } if child != nil { - if !file.isNil() && inoFromPath(qid.Path) == child.ino { + if !file.isNil() && qid.Path == child.qidPath { // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) child.updateFromP9AttrsLocked(attrMask, &attr) @@ -1512,7 +1518,6 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath d.IncRef() return &endpoint{ dentry: d, - file: d.file.file, path: opts.Addr, }, nil } @@ -1591,7 +1596,3 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } - -func (fs *filesystem) nextSyntheticIno() inodeNumber { - return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask) -} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index f1dad1b08..80668ebc1 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -26,6 +26,9 @@ // *** "memmap.Mappable locks taken by Translate" below this point // dentry.handleMu // dentry.dataMu +// filesystem.inoMu +// specialFileFD.mu +// specialFileFD.bufMu // // Locking dentry.dirMu in multiple dentries requires that either ancestor // dentries are locked before descendant dentries, or that filesystem.renameMu @@ -36,7 +39,6 @@ import ( "fmt" "strconv" "strings" - "sync" "sync/atomic" "syscall" @@ -44,6 +46,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -53,6 +57,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/usermem" @@ -81,7 +86,7 @@ type filesystem struct { iopts InternalFilesystemOptions // client is the client used by this filesystem. client is immutable. - client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + client *p9.Client `state:"nosave"` // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -89,6 +94,9 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 + // root is the root dentry. root is immutable. + root *dentry + // renameMu serves two purposes: // // - It synchronizes path resolution with renaming initiated by this @@ -103,39 +111,35 @@ type filesystem struct { // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) - // cachedDentriesLen is the number of dentries in cachedDentries. These - // fields are protected by renameMu. + // cachedDentriesLen is the number of dentries in cachedDentries. These fields + // are protected by renameMu. cachedDentries dentryList cachedDentriesLen uint64 - // syncableDentries contains all dentries in this filesystem for which - // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs. - // These fields are protected by syncMu. + // syncableDentries contains all non-synthetic dentries. specialFileFDs + // contains all open specialFileFDs. These fields are protected by syncMu. syncMu sync.Mutex `state:"nosave"` syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} - // syntheticSeq stores a counter to used to generate unique inodeNumber for - // synthetic dentries. - syntheticSeq uint64 -} + // inoByQIDPath maps previously-observed QID.Paths to inode numbers + // assigned to those paths. inoByQIDPath is not preserved across + // checkpoint/restore because QIDs may be reused between different gofer + // processes, so QIDs may be repeated for different files across + // checkpoint/restore. inoByQIDPath is protected by inoMu. + inoMu sync.Mutex `state:"nosave"` + inoByQIDPath map[uint64]uint64 `state:"nosave"` -// inodeNumber represents inode number reported in Dirent.Ino. For regular -// dentries, it comes from QID.Path from the 9P server. Synthetic dentries -// have have their inodeNumber generated sequentially, with the MSB reserved to -// prevent conflicts with regular dentries. -// -// +stateify savable -type inodeNumber uint64 + // lastIno is the last inode number assigned to a file. lastIno is accessed + // using atomic memory operations. + lastIno uint64 -// Reserve MSB for synthetic mounts. -const syntheticInoMask = uint64(1) << 63 + // savedDentryRW records open read/write handles during save/restore. + savedDentryRW map[*dentry]savedDentryRW -func inoFromPath(path uint64) inodeNumber { - if path&syntheticInoMask != 0 { - log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask) - } - return inodeNumber(path &^ syntheticInoMask) + // released is nonzero once filesystem.Release has been called. It is accessed + // with atomic memory operations. + released int32 } // +stateify savable @@ -149,8 +153,7 @@ type filesystemOptions struct { msize uint32 version string - // maxCachedDentries is the maximum number of dentries with 0 references - // retained by the client. + // maxCachedDentries is the maximum size of filesystem.cachedDentries. maxCachedDentries uint64 // If forcePageCache is true, host FDs may not be used for application @@ -247,6 +250,10 @@ const ( // // +stateify savable type InternalFilesystemOptions struct { + // If UniqueID is non-empty, it is an opaque string used to reassociate the + // filesystem with a new server FD during restoration from checkpoint. + UniqueID string + // If LeakConnection is true, do not close the connection to the server // when the Filesystem is released. This is necessary for deployments in // which servers can handle only a single client and report failure if that @@ -286,46 +293,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mopts := vfs.GenericParseMountOptions(opts.Data) var fsopts filesystemOptions - // Check that the transport is "fd". - trans, ok := mopts["trans"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'") - return nil, nil, syserror.EINVAL - } - delete(mopts, "trans") - if trans != "fd" { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans) - return nil, nil, syserror.EINVAL - } - - // Check that read and write FDs are provided and identical. - rfdstr, ok := mopts["rfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "rfdno") - rfd, err := strconv.Atoi(rfdstr) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr) - return nil, nil, syserror.EINVAL - } - wfdstr, ok := mopts["wfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "wfdno") - wfd, err := strconv.Atoi(wfdstr) + fd, err := getFDFromMountOptionsMap(ctx, mopts) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr) - return nil, nil, syserror.EINVAL - } - if rfd != wfd { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd) - return nil, nil, syserror.EINVAL + return nil, nil, err } - fsopts.fd = rfd + fsopts.fd = fd // Get the attach name. fsopts.aname = "/" @@ -441,57 +413,44 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // If !ok, iopts being the zero value is correct. - // Establish a connection with the server. - conn, err := unet.NewSocket(fsopts.fd) + // Construct the filesystem object. + devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } + fs := &filesystem{ + mfp: mfp, + opts: fsopts, + iopts: iopts, + clock: ktime.RealtimeClockFromContext(ctx), + devMinor: devMinor, + syncableDentries: make(map[*dentry]struct{}), + specialFileFDs: make(map[*specialFileFD]struct{}), + inoByQIDPath: make(map[uint64]uint64), + } + fs.vfsfs.Init(vfsObj, &fstype, fs) - // Perform version negotiation with the server. - ctx.UninterruptibleSleepStart(false) - client, err := p9.NewClient(conn, fsopts.msize, fsopts.version) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - conn.Close() + // Connect to the server. + if err := fs.dial(ctx); err != nil { return nil, nil, err } - // Ownership of conn has been transferred to client. // Perform attach to obtain the filesystem root. ctx.UninterruptibleSleepStart(false) - attached, err := client.Attach(fsopts.aname) + attached, err := fs.client.Attach(fsopts.aname) ctx.UninterruptibleSleepFinish(false) if err != nil { - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } attachFile := p9file{attached} qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) if err != nil { attachFile.close(ctx) - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } - // Construct the filesystem object. - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - attachFile.close(ctx) - client.Close() - return nil, nil, err - } - fs := &filesystem{ - mfp: mfp, - opts: fsopts, - iopts: iopts, - client: client, - clock: ktime.RealtimeClockFromContext(ctx), - devMinor: devMinor, - syncableDentries: make(map[*dentry]struct{}), - specialFileFDs: make(map[*specialFileFD]struct{}), - } - fs.vfsfs.Init(vfsObj, &fstype, fs) - // Construct the root dentry. root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) if err != nil { @@ -500,25 +459,87 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } // Set the root's reference count to 2. One reference is returned to the - // caller, and the other is deliberately leaked to prevent the root from - // being "cached" and subsequently evicted. Its resources will still be - // cleaned up by fs.Release(). + // caller, and the other is held by fs to prevent the root from being "cached" + // and subsequently evicted. root.refs = 2 + fs.root = root return &fs.vfsfs, &root.vfsd, nil } +func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { + // Check that the transport is "fd". + trans, ok := mopts["trans"] + if !ok || trans != "fd" { + ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'") + return -1, syserror.EINVAL + } + delete(mopts, "trans") + + // Check that read and write FDs are provided and identical. + rfdstr, ok := mopts["rfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "rfdno") + rfd, err := strconv.Atoi(rfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr) + return -1, syserror.EINVAL + } + wfdstr, ok := mopts["wfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "wfdno") + wfd, err := strconv.Atoi(wfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr) + return -1, syserror.EINVAL + } + if rfd != wfd { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd) + return -1, syserror.EINVAL + } + return rfd, nil +} + +// Preconditions: fs.client == nil. +func (fs *filesystem) dial(ctx context.Context) error { + // Establish a connection with the server. + conn, err := unet.NewSocket(fs.opts.fd) + if err != nil { + return err + } + + // Perform version negotiation with the server. + ctx.UninterruptibleSleepStart(false) + client, err := p9.NewClient(conn, fs.opts.msize, fs.opts.version) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + conn.Close() + return err + } + // Ownership of conn has been transferred to client. + + fs.client = client + return nil +} + // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { - mf := fs.mfp.MemoryFile() + atomic.StoreInt32(&fs.released, 1) + mf := fs.mfp.MemoryFile() fs.syncMu.Lock() for d := range fs.syncableDentries { d.handleMu.Lock() d.dataMu.Lock() if h := d.writeHandleLocked(); h.isOpen() { // Write dirty cached data to the remote file. - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err) } // TODO(jamieliu): Do we need to flushf/fsync d? @@ -539,6 +560,21 @@ func (fs *filesystem) Release(ctx context.Context) { // fs. fs.syncMu.Unlock() + // If leak checking is enabled, release all outstanding references in the + // filesystem. We deliberately avoid doing this outside of leak checking; we + // have released all external resources above rather than relying on dentry + // destructors. + if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { + fs.renameMu.Lock() + fs.root.releaseSyntheticRecursiveLocked(ctx) + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // An extra reference was held by the filesystem on the root to prevent it from + // being cached/evicted. + fs.root.DecRef(ctx) + } + if !fs.iopts.LeakConnection { // Close the connection to the server. This implicitly clunks all fids. fs.client.Close() @@ -547,6 +583,31 @@ func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } +// releaseSyntheticRecursiveLocked traverses the tree with root d and decrements +// the reference count on every synthetic dentry. Synthetic dentries have one +// reference for existence that should be dropped during filesystem.Release. +// +// Precondition: d.fs.renameMu is locked. +func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { + if d.isSynthetic() { + d.decRefLocked() + d.checkCachingLocked(ctx) + } + if d.isDir() { + var children []*dentry + d.dirMu.Lock() + for _, child := range d.children { + children = append(children, child) + } + d.dirMu.Unlock() + for _, child := range children { + if child != nil { + child.releaseSyntheticRecursiveLocked(ctx) + } + } + } +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -574,12 +635,15 @@ type dentry struct { // filesystem.renameMu. name string + // qidPath is the p9.QID.Path for this file. qidPath is immutable. + qidPath uint64 + // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file // that does not exist on the remote filesystem. As of this writing, the // only files that can be synthetic are sockets, pipes, and directories. - file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + file p9file `state:"nosave"` // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. @@ -623,12 +687,12 @@ type dentry struct { // To mutate: // - Lock metadataMu and use atomic operations to update because we might // have atomic readers that don't hold the lock. - metadataMu sync.Mutex `state:"nosave"` - ino inodeNumber // immutable - mode uint32 // type is immutable, perms are mutable - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - blockSize uint32 // 0 if unknown + metadataMu sync.Mutex `state:"nosave"` + ino uint64 // immutable + mode uint32 // type is immutable, perms are mutable + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + blockSize uint32 // 0 if unknown // Timestamps, all nsecs from the Unix epoch. atime int64 mtime int64 @@ -679,9 +743,9 @@ type dentry struct { // (isNil() == false), it may be mutated with handleMu locked, but cannot // be closed until the dentry is destroyed. handleMu sync.RWMutex `state:"nosave"` - readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - hostFD int32 + readFile p9file `state:"nosave"` + writeFile p9file `state:"nosave"` + hostFD int32 `state:"nosave"` dataMu sync.RWMutex `state:"nosave"` @@ -758,8 +822,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d := &dentry{ fs: fs, + qidPath: qid.Path, file: file, - ino: inoFromPath(qid.Path), + ino: fs.inoFromQIDPath(qid.Path), mode: uint32(attr.Mode), uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), @@ -795,6 +860,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d.nlink = uint32(attr.NLink) } d.vfsd.Init(d) + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(d, "gofer.dentry") + } fs.syncMu.Lock() fs.syncableDentries[d] = struct{}{} @@ -802,6 +870,21 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma return d, nil } +func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 { + fs.inoMu.Lock() + defer fs.inoMu.Unlock() + if ino, ok := fs.inoByQIDPath[qidPath]; ok { + return ino + } + ino := fs.nextIno() + fs.inoByQIDPath[qidPath] = ino + return ino +} + +func (fs *filesystem) nextIno() uint64 { + return atomic.AddUint64(&fs.lastIno, 1) +} + func (d *dentry) isSynthetic() bool { return d.file.isNil() } @@ -853,7 +936,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } } -// Preconditions: !d.isSynthetic() +// Preconditions: !d.isSynthetic(). func (d *dentry) updateFromGetattr(ctx context.Context) error { // Use d.readFile or d.writeFile, which represent 9P fids that have been // opened, in preference to d.file, which represents a 9P fid that has not. @@ -916,10 +999,10 @@ func (d *dentry) statTo(stat *linux.Statx) { // This is consistent with regularFileFD.Seek(), which treats regular files // as having no holes. stat.Blocks = (stat.Size + 511) / 512 - stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime)) - stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime)) - stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime)) - stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime)) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.atime)) + stat.Btime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.btime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.ctime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.mtime)) stat.DevMajor = linux.UNNAMED_MAJOR stat.DevMinor = d.fs.devMinor } @@ -967,10 +1050,10 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // Use client clocks for timestamps. now = d.fs.clock.Now().Nanoseconds() if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW { - stat.Atime = statxTimestampFromDentry(now) + stat.Atime = linux.NsecToStatxTimestamp(now) } if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW { - stat.Mtime = statxTimestampFromDentry(now) + stat.Mtime = linux.NsecToStatxTimestamp(now) } } @@ -1029,11 +1112,11 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // !d.cachedMetadataAuthoritative() then we returned after calling // d.file.setAttr(). For the same reason, now must have been initialized. if stat.Mask&linux.STATX_ATIME != 0 { - atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) + atomic.StoreInt64(&d.atime, stat.Atime.ToNsec()) atomic.StoreUint32(&d.atimeDirty, 0) } if stat.Mask&linux.STATX_MTIME != 0 { - atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) + atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec()) atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) @@ -1175,6 +1258,11 @@ func (d *dentry) decRefLocked() { } } +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[gofer.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { if d.isDir() { @@ -1223,6 +1311,10 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { // resolution, which requires renameMu, so if d.refs is zero then it will // remain zero while we hold renameMu for writing.) refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + return + } if refs > 0 { if d.cached { d.fs.cachedDentries.Remove(d) @@ -1231,10 +1323,6 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { } return } - if refs == -1 { - // Dentry has already been destroyed. - return - } // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { @@ -1257,6 +1345,16 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { if d.watches.Size() > 0 { return } + + if atomic.LoadInt32(&d.fs.released) != 0 { + if d.parent != nil { + d.parent.dirMu.Lock() + delete(d.parent.children, d.name) + d.parent.dirMu.Unlock() + } + d.destroyLocked(ctx) + } + // If d is already cached, just move it to the front of the LRU. if d.cached { d.fs.cachedDentries.Remove(d) @@ -1269,33 +1367,48 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { d.fs.cachedDentriesLen++ d.cached = true if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { - victim := d.fs.cachedDentries.Back() - d.fs.cachedDentries.Remove(victim) - d.fs.cachedDentriesLen-- - victim.cached = false - // victim.refs may have become non-zero from an earlier path resolution - // since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 { - if victim.parent != nil { - victim.parent.dirMu.Lock() - if !victim.vfsd.IsDead() { - // Note that victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount points. - d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) - delete(victim.parent.children, victim.name) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victim.parent.dirMu.Unlock() - } - victim.destroyLocked(ctx) - } + d.fs.evictCachedDentryLocked(ctx) // Whether or not victim was destroyed, we brought fs.cachedDentriesLen // back down to fs.opts.maxCachedDentries, so we don't loop. } } +// Precondition: fs.renameMu must be locked for writing; it may be temporarily +// unlocked. +func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } +} + +// Preconditions: +// * fs.renameMu must be locked for writing; it may be temporarily unlocked. +// * fs.cachedDentriesLen != 0. +func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + victim := fs.cachedDentries.Back() + fs.cachedDentries.Remove(victim) + fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // since it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if victim.parent != nil { + victim.parent.dirMu.Lock() + if !victim.vfsd.IsDead() { + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + // We're only deleting the dentry, not the file it + // represents, so we don't need to update + // victimParent.dirents etc. + } + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } +} + // destroyLocked destroys the dentry. // // Preconditions: @@ -1380,6 +1493,10 @@ func (d *dentry) destroyLocked(ctx context.Context) { panic("gofer.dentry.DecRef() called without holding a reference") } } + + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(d, "gofer.dentry") + } } func (d *dentry) isDeleted() bool { @@ -1623,6 +1740,33 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { return nil } +func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() + if h.isOpen() { + // Write back dirty pages to the remote file. + d.dataMu.Lock() + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) + d.dataMu.Unlock() + if err != nil { + return err + } + } + if err := d.syncRemoteFileLocked(ctx); err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (d is a regular file and was opened for writing). + if d.isRegularFile() && h.isOpen() { + return err + } + ctx.Debugf("gofer.dentry.syncCachedFile: syncing non-writable or non-regular-file dentry failed: %v", err) + } + return nil +} + // incLinks increments link count. func (d *dentry) incLinks() { if atomic.LoadUint32(&d.nlink) == 0 { @@ -1650,7 +1794,7 @@ type fileDescription struct { vfs.FileDescriptionDefaultImpl vfs.LockFD - lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + lockLogging sync.Once `state:"nosave"` } func (fd *fileDescription) filesystem() *filesystem { diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index bfe75dfe4..76f08e252 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -26,12 +26,13 @@ import ( func TestDestroyIdempotent(t *testing.T) { ctx := contexttest.Context(t) fs := filesystem{ - mfp: pgalloc.MemoryFileProviderFromContext(ctx), - syncableDentries: make(map[*dentry]struct{}), + mfp: pgalloc.MemoryFileProviderFromContext(ctx), opts: filesystemOptions{ // Test relies on no dentry being held in the cache. maxCachedDentries: 0, }, + syncableDentries: make(map[*dentry]struct{}), + inoByQIDPath: make(map[uint64]uint64), } attr := &p9.Attr{ diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go index 7294de7d6..c7bf10007 100644 --- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -51,8 +51,24 @@ func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error { if ok { return nil } - if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { - return err + if sleepErr := sleepBetweenNamedPipeOpenChecks(ctx); sleepErr != nil { + // Another application thread may have opened this pipe for + // writing, succeeded because we previously opened the pipe for + // reading, and subsequently interrupted us for checkpointing (e.g. + // this occurs in mknod tests under cooperative save/restore). In + // this case, our open has to succeed for the checkpoint to include + // a readable FD for the pipe, which is in turn necessary to + // restore the other thread's writable FD for the same pipe + // (otherwise it will get ENXIO). So we have to check + // nonblockingPipeHasWriter() once last time. + ok, err := nonblockingPipeHasWriter(fd) + if err != nil { + return err + } + if ok { + return nil + } + return sleepErr } } } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index f8b19bae7..dc8a890cb 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -18,7 +18,6 @@ import ( "fmt" "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -624,23 +624,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncCachedFile(ctx) -} - -func (d *dentry) syncCachedFile(ctx context.Context) error { - d.handleMu.RLock() - defer d.handleMu.RUnlock() - - if h := d.writeHandleLocked(); h.isOpen() { - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) - d.dataMu.Unlock() - if err != nil { - return err - } - } - return d.syncRemoteFileLocked(ctx) + return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -913,7 +897,7 @@ type dentryPlatformFile struct { hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. - hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + hostFileMapperInitOnce sync.Once `state:"nosave"` } // IncRef implements memmap.File.IncRef. diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go new file mode 100644 index 000000000..2ea224c43 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -0,0 +1,329 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "fmt" + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type saveRestoreContextID int + +const ( + // CtxRestoreServerFDMap is a Context.Value key for a map[string]int + // mapping filesystem unique IDs (cf. InternalFilesystemOptions.UniqueID) + // to host FDs. + CtxRestoreServerFDMap saveRestoreContextID = iota +) + +// +stateify savable +type savedDentryRW struct { + read bool + write bool +} + +// PreprareSave implements vfs.FilesystemImplSaveRestoreExtension.PrepareSave. +func (fs *filesystem) PrepareSave(ctx context.Context) error { + if len(fs.iopts.UniqueID) == 0 { + return fmt.Errorf("gofer.filesystem with no UniqueID cannot be saved") + } + + // Purge cached dentries, which may not be reopenable after restore due to + // permission changes. + fs.renameMu.Lock() + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // Buffer pipe data so that it's available for reading after restore. (This + // is a legacy VFS1 feature.) + fs.syncMu.Lock() + for sffd := range fs.specialFileFDs { + if sffd.dentry().fileType() == linux.S_IFIFO && sffd.vfsfd.IsReadable() { + if err := sffd.savePipeData(ctx); err != nil { + fs.syncMu.Unlock() + return err + } + } + } + fs.syncMu.Unlock() + + // Flush local state to the remote filesystem. + if err := fs.Sync(ctx); err != nil { + return err + } + + fs.savedDentryRW = make(map[*dentry]savedDentryRW) + return fs.root.prepareSaveRecursive(ctx) +} + +// Preconditions: +// * fd represents a pipe. +// * fd is readable. +func (fd *specialFileFD) savePipeData(ctx context.Context) error { + fd.bufMu.Lock() + defer fd.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0)) + if n != 0 { + fd.buf = append(fd.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syserror.EAGAIN { + break + } + return err + } + } + if len(fd.buf) != 0 { + atomic.StoreUint32(&fd.haveBuf, 1) + } + return nil +} + +func (d *dentry) prepareSaveRecursive(ctx context.Context) error { + if d.isRegularFile() && !d.cachedMetadataAuthoritative() { + // Get updated metadata for d in case we need to perform metadata + // validation during restore. + if err := d.updateFromGetattr(ctx); err != nil { + return err + } + } + if !d.readFile.isNil() || !d.writeFile.isNil() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: !d.readFile.isNil(), + write: !d.writeFile.isNil(), + } + } + d.dirMu.Lock() + defer d.dirMu.Unlock() + for _, child := range d.children { + if child != nil { + if err := child.prepareSaveRecursive(ctx); err != nil { + return err + } + } + } + return nil +} + +// beforeSave is invoked by stateify. +func (d *dentry) beforeSave() { + if d.vfsd.IsDead() { + panic(fmt.Sprintf("gofer.dentry(%q).beforeSave: deleted and invalidated dentries can't be restored", genericDebugPathname(d))) + } +} + +// afterLoad is invoked by stateify. +func (d *dentry) afterLoad() { + d.hostFD = -1 + if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d, "gofer.dentry") + } +} + +// afterLoad is invoked by stateify. +func (d *dentryPlatformFile) afterLoad() { + if d.hostFileMapper.IsInited() { + // Ensure that we don't call d.hostFileMapper.Init() again. + d.hostFileMapperInitOnce.Do(func() {}) + } +} + +// afterLoad is invoked by stateify. +func (fd *specialFileFD) afterLoad() { + fd.handle.fd = -1 +} + +// CompleteRestore implements +// vfs.FilesystemImplSaveRestoreExtension.CompleteRestore. +func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRestoreOptions) error { + fdmapv := ctx.Value(CtxRestoreServerFDMap) + if fdmapv == nil { + return fmt.Errorf("no server FD map available") + } + fdmap := fdmapv.(map[string]int) + fd, ok := fdmap[fs.iopts.UniqueID] + if !ok { + return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID) + } + fs.opts.fd = fd + if err := fs.dial(ctx); err != nil { + return err + } + fs.inoByQIDPath = make(map[uint64]uint64) + + // Restore the filesystem root. + ctx.UninterruptibleSleepStart(false) + attached, err := fs.client.Attach(fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return err + } + attachFile := p9file{attached} + qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) + if err != nil { + return err + } + if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { + return err + } + + // Restore remaining dentries. + if err := fs.root.restoreDescendantsRecursive(ctx, &opts); err != nil { + return err + } + + // Re-open handles for specialFileFDs. Unlike the initial open + // (dentry.openSpecialFile()), pipes are always opened without blocking; + // non-readable pipe FDs are opened last to ensure that they don't get + // ENXIO if another specialFileFD represents the read end of the same pipe. + // This is consistent with VFS1. + haveWriteOnlyPipes := false + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + haveWriteOnlyPipes = true + continue + } + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + if haveWriteOnlyPipes { + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + } + } + + // Discard state only required during restore. + fs.savedDentryRW = nil + + return nil +} + +func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrMask p9.AttrMask, attr *p9.Attr, opts *vfs.CompleteRestoreOptions) error { + d.file = file + + // Gofers do not preserve QID across checkpoint/restore, so: + // + // - We must assume that the remote filesystem did not change in a way that + // would invalidate dentries, since we can't revalidate dentries by + // checking QIDs. + // + // - We need to associate the new QID.Path with the existing d.ino. + d.qidPath = qid.Path + d.fs.inoMu.Lock() + d.fs.inoByQIDPath[qid.Path] = d.ino + d.fs.inoMu.Unlock() + + // Check metadata stability before updating metadata. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + if d.isRegularFile() { + if opts.ValidateFileSizes { + if !attrMask.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d)) + } + if d.size != attr.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size) + } + } + if opts.ValidateFileModificationTimestamps { + if !attrMask.MTime { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d)) + } + if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want)) + } + } + } + if !d.cachedMetadataAuthoritative() { + d.updateFromP9AttrsLocked(attrMask, attr) + } + + if rw, ok := d.fs.savedDentryRW[d]; ok { + if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil { + return err + } + } + + return nil +} + +// Preconditions: d is not synthetic. +func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + for _, child := range d.children { + if child == nil { + continue + } + if _, ok := d.fs.syncableDentries[child]; !ok { + // child is synthetic. + continue + } + if err := child.restoreRecursive(ctx, opts); err != nil { + return err + } + } + return nil +} + +// Preconditions: d is not synthetic (but note that since this function +// restores d.file, d.file.isNil() is always true at this point, so this can +// only be detected by checking filesystem.syncableDentries). d.parent has been +// restored. +func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { + return err + } + return d.restoreDescendantsRecursive(ctx, opts) +} + +func (fd *specialFileFD) completeRestore(ctx context.Context) error { + d := fd.dentry() + h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + if err != nil { + return err + } + fd.handle = h + + ftype := d.fileType() + fd.haveQueue = (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && fd.handle.fd >= 0 + if fd.haveQueue { + if err := fdnotifier.AddFD(fd.handle.fd, &fd.queue); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index 326b940a7..a21199eac 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -42,9 +42,6 @@ type endpoint struct { // dentry is the filesystem dentry which produced this endpoint. dentry *dentry - // file is the p9 file that contains a single unopened fid. - file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - // path is the sentry path where this endpoint is bound. path string } @@ -116,7 +113,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect } func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { - hostFile, err := e.file.Connect(flags) + hostFile, err := e.dentry.file.connect(ctx, flags) if err != nil { return nil, syserr.ErrConnectionRefused } @@ -131,7 +128,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path) if serr != nil { - log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr) + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr) return nil, serr } return c, nil diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 71581736c..625400c0b 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -15,7 +15,6 @@ package gofer import ( - "sync" "sync/atomic" "syscall" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -40,7 +40,7 @@ type specialFileFD struct { fileDescription // handle is used for file I/O. handle is immutable. - handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + handle handle `state:"nosave"` // isRegularFile is true if this FD represents a regular file which is only // possible when filesystemOptions.regularFilesUseSpecialFileFD is in @@ -54,12 +54,20 @@ type specialFileFD struct { // haveQueue is true if this file description represents a file for which // queue may send I/O readiness events. haveQueue is immutable. - haveQueue bool + haveQueue bool `state:"nosave"` queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex `state:"nosave"` off int64 + + // If haveBuf is non-zero, this FD represents a pipe, and buf contains data + // read from the pipe from previous calls to specialFileFD.savePipeData(). + // haveBuf and buf are protected by bufMu. haveBuf is accessed using atomic + // memory operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte } func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { @@ -87,6 +95,9 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, } return nil, err } + d.fs.syncMu.Lock() + d.fs.specialFileFDs[fd] = struct{}{} + d.fs.syncMu.Unlock() return fd, nil } @@ -161,26 +172,51 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs return 0, syserror.EOPNOTSUPP } - // Going through dst.CopyOutFrom() holds MM locks around file operations of - // unknown duration. For regularFileFD, doing so is necessary to support - // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't - // hold here since specialFileFD doesn't client-cache data. Just buffer the - // read instead. if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } + + bufN := int64(0) + if atomic.LoadUint32(&fd.haveBuf) != 0 { + var err error + fd.bufMu.Lock() + if len(fd.buf) != 0 { + var n int + n, err = dst.CopyOut(ctx, fd.buf) + dst = dst.DropFirst(n) + fd.buf = fd.buf[n:] + if len(fd.buf) == 0 { + atomic.StoreUint32(&fd.haveBuf, 0) + fd.buf = nil + } + bufN = int64(n) + if offset >= 0 { + offset += bufN + } + } + fd.bufMu.Unlock() + if err != nil { + return bufN, err + } + } + + // Going through dst.CopyOutFrom() would hold MM locks around file + // operations of unknown duration. For regularFileFD, doing so is necessary + // to support mmap due to lock ordering; MM locks precede dentry.dataMu. + // That doesn't hold here since specialFileFD doesn't client-cache data. + // Just buffer the read instead. buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } if n == 0 { - return 0, err + return bufN, err } if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil { - return int64(cp), cperr + return bufN + int64(cp), cperr } - return int64(n), err + return bufN + int64(n), err } // Read implements vfs.FileDescriptionImpl.Read. @@ -217,16 +253,16 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } d := fd.dentry() - // If the regular file fd was opened with O_APPEND, make sure the file size - // is updated. There is a possible race here if size is modified externally - // after metadata cache is updated. - if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { - if err := d.updateFromGetattr(ctx); err != nil { - return 0, offset, err + if fd.isRegularFile { + // If the regular file fd was opened with O_APPEND, make sure the file + // size is updated. There is a possible race here if size is modified + // externally after metadata cache is updated. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } - } - if fd.isRegularFile { // We need to hold the metadataMu *while* writing to a regular file. d.metadataMu.Lock() defer d.metadataMu.Unlock() @@ -306,13 +342,31 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - // If we have a host FD, fsyncing it is likely to be faster than an fsync - // RPC. - if fd.handle.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(fd.handle.fd)) - ctx.UninterruptibleSleepFinish(false) - return err + return fd.sync(ctx, false /* forFilesystemSync */) +} + +func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error { + err := func() error { + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if fd.handle.fd >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(fd.handle.fd)) + ctx.UninterruptibleSleepFinish(false) + return err + } + return fd.handle.file.fsync(ctx) + }() + if err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (fd represents a regular file that was opened for writing). + if fd.isRegularFile && fd.vfsfd.IsWritable() { + return err + } + ctx.Debugf("gofer.specialFileFD.sync: syncing non-writable or non-regular-file FD failed: %v", err) } - return fd.handle.file.fsync(ctx) + return nil } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 7e825caae..9cbe805b9 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -17,7 +17,6 @@ package gofer import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -25,17 +24,6 @@ func dentryTimestampFromP9(s, ns uint64) int64 { return int64(s*1e9 + ns) } -func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 { - return ts.Sec*1e9 + int64(ts.Nsec) -} - -func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { - return linux.StatxTimestamp{ - Sec: ns / 1e9, - Nsec: uint32(ns % 1e9), - } -} - // Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime || mnt.ReadOnly() { diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 56bcf9bdb..dc0f86061 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "inode_refs.go", package = "host", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "connected_endpoint_refs.go", package = "host", prefix = "ConnectedEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ConnectedEndpoint", }, @@ -34,6 +34,7 @@ go_library( "inode_refs.go", "ioctl_unsafe.go", "mmap.go", + "save_restore.go", "socket.go", "socket_iovec.go", "socket_unsafe.go", @@ -51,6 +52,7 @@ go_library( "//pkg/log", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go index 0135e4428..13ef48cb5 100644 --- a/pkg/sentry/fsimpl/host/control.go +++ b/pkg/sentry/fsimpl/host/control.go @@ -79,7 +79,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription { } // Create the file backed by hostFD. - file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */) + file, err := NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, &NewFDOptions{}) if err != nil { ctx.Warningf("Error creating file from host FD: %v", err) break diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 698e913fe..eeed0f97d 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -19,6 +19,7 @@ package host import ( "fmt" "math" + "sync/atomic" "syscall" "golang.org/x/sys/unix" @@ -40,34 +41,106 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) { - // Determine if hostFD is seekable. If not, this syscall will return ESPIPE - // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character - // devices. +// inode implements kernfs.Inode. +// +// +stateify savable +type inode struct { + kernfs.InodeNoStatFS + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. + + locks vfs.FileLocks + + // When the reference count reaches zero, the host fd is closed. + inodeRefs + + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + hostFD int + + // ino is an inode number unique within this filesystem. + // + // This field is initialized at creation time and is immutable. + ino uint64 + + // ftype is the file's type (a linux.S_IFMT mask). + // + // This field is initialized at creation time and is immutable. + ftype uint16 + + // mayBlock is true if hostFD is non-blocking, and operations on it may + // return EAGAIN or EWOULDBLOCK instead of blocking. + // + // This field is initialized at creation time and is immutable. + mayBlock bool + + // seekable is false if lseek(hostFD) returns ESPIPE. We assume that file + // offsets are meaningful iff seekable is true. + // + // This field is initialized at creation time and is immutable. + seekable bool + + // isTTY is true if this file represents a TTY. + // + // This field is initialized at creation time and is immutable. + isTTY bool + + // savable is true if hostFD may be saved/restored by its numeric value. + // + // This field is initialized at creation time and is immutable. + savable bool + + // Event queue for blocking operations. + queue waiter.Queue + + // mapsMu protects mappings. + mapsMu sync.Mutex `state:"nosave"` + + // If this file is mmappable, mappings tracks mappings of hostFD into + // memmap.MappingSpaces. + mappings memmap.MappingSet + + // pf implements platform.File for mappings of hostFD. + pf inodePlatformFile + + // If haveBuf is non-zero, hostFD represents a pipe, and buf contains data + // read from the pipe from previous calls to inode.beforeSave(). haveBuf + // and buf are protected by bufMu. haveBuf is accessed using atomic memory + // operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte +} + +func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) { + // Determine if hostFD is seekable. _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) seekable := err != syserror.ESPIPE + // We expect regular files to be seekable, as this is required for them to + // be memory-mappable. + if !seekable && fileType == syscall.S_IFREG { + ctx.Infof("host.newInode: host FD %d is a non-seekable regular file", hostFD) + return nil, syserror.ESPIPE + } i := &inode{ - hostFD: hostFD, - ino: fs.NextIno(), - isTTY: isTTY, - wouldBlock: wouldBlock(uint32(fileType)), - seekable: seekable, - // NOTE(b/38213152): Technically, some obscure char devices can be memory - // mapped, but we only allow regular files. - canMap: fileType == linux.S_IFREG, + hostFD: hostFD, + ino: fs.NextIno(), + ftype: uint16(fileType), + mayBlock: fileType != syscall.S_IFREG && fileType != syscall.S_IFDIR, + seekable: seekable, + isTTY: isTTY, + savable: savable, } i.pf.inode = i i.EnableLeakCheck() - // Non-seekable files can't be memory mapped, assert this. - if !i.seekable && i.canMap { - panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") - } - - // If the hostFD would block, we must set it to non-blocking and handle - // blocking behavior in the sentry. - if i.wouldBlock { + // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and + // handle blocking behavior in the sentry. + if i.mayBlock { if err := syscall.SetNonblock(i.hostFD, true); err != nil { return nil, err } @@ -80,6 +153,11 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) ( // NewFDOptions contains options to NewFD. type NewFDOptions struct { + // If Savable is true, the host file descriptor may be saved/restored by + // numeric value; the sandbox API requires a corresponding host FD with the + // same numeric value to be provieded at time of restore. + Savable bool + // If IsTTY is true, the file descriptor is a TTY. IsTTY bool @@ -114,7 +192,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) } d := &kernfs.Dentry{} - i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY) + i, err := newInode(ctx, fs, hostFD, opts.Savable, linux.FileMode(s.Mode).FileType(), opts.IsTTY) if err != nil { return nil, err } @@ -132,7 +210,8 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) // ImportFD sets up and returns a vfs.FileDescription from a donated fd. func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) { return NewFD(ctx, mnt, hostFD, &NewFDOptions{ - IsTTY: isTTY, + Savable: true, + IsTTY: isTTY, }) } @@ -191,68 +270,6 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } -// inode implements kernfs.Inode. -// -// +stateify savable -type inode struct { - kernfs.InodeNoStatFS - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink - kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. - - locks vfs.FileLocks - - // When the reference count reaches zero, the host fd is closed. - inodeRefs - - // hostFD contains the host fd that this file was originally created from, - // which must be available at time of restore. - // - // This field is initialized at creation time and is immutable. - hostFD int - - // ino is an inode number unique within this filesystem. - // - // This field is initialized at creation time and is immutable. - ino uint64 - - // isTTY is true if this file represents a TTY. - // - // This field is initialized at creation time and is immutable. - isTTY bool - - // seekable is false if the host fd points to a file representing a stream, - // e.g. a socket or a pipe. Such files are not seekable and can return - // EWOULDBLOCK for I/O operations. - // - // This field is initialized at creation time and is immutable. - seekable bool - - // wouldBlock is true if the host FD would return EWOULDBLOCK for - // operations that would block. - // - // This field is initialized at creation time and is immutable. - wouldBlock bool - - // Event queue for blocking operations. - queue waiter.Queue - - // canMap specifies whether we allow the file to be memory mapped. - // - // This field is initialized at creation time and is immutable. - canMap bool - - // mapsMu protects mappings. - mapsMu sync.Mutex `state:"nosave"` - - // If canMap is true, mappings tracks mappings of hostFD into - // memmap.MappingSpaces. - mappings memmap.MappingSet - - // pf implements platform.File for mappings of hostFD. - pf inodePlatformFile -} - // CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s syscall.Stat_t @@ -448,7 +465,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre // DecRef implements kernfs.Inode.DecRef. func (i *inode) DecRef(ctx context.Context) { i.inodeRefs.DecRef(func() { - if i.wouldBlock { + if i.mayBlock { fdnotifier.RemoveFD(int32(i.hostFD)) } if err := unix.Close(i.hostFD); err != nil { @@ -567,6 +584,13 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin // PRead implements vfs.FileDescriptionImpl.PRead. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { return 0, syserror.ESPIPE @@ -577,19 +601,31 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off // Read implements vfs.FileDescriptionImpl.Read. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { + bufN, err := i.readFromBuf(ctx, &dst) + if err != nil { + return bufN, err + } n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags) + total := bufN + n if isBlockError(err) { // If we got any data at all, return it as a "completed" partial read // rather than retrying until complete. - if n != 0 { + if total != 0 { err = nil } else { err = syserror.ErrWouldBlock } } - return n, err + return total, err } f.offsetMu.Lock() @@ -599,13 +635,26 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts return n, err } -func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // Check that flags are supported. - // - // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. - if flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP +func (i *inode) readFromBuf(ctx context.Context, dst *usermem.IOSequence) (int64, error) { + if atomic.LoadUint32(&i.haveBuf) == 0 { + return 0, nil } + i.bufMu.Lock() + defer i.bufMu.Unlock() + if len(i.buf) == 0 { + return 0, nil + } + n, err := dst.CopyOut(ctx, i.buf) + *dst = dst.DropFirst(n) + i.buf = i.buf[n:] + if len(i.buf) == 0 { + atomic.StoreUint32(&i.haveBuf, 0) + i.buf = nil + } + return int64(n), err +} + +func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) n, err := dst.CopyOutFrom(ctx, reader) hostfd.PutReadWriterAt(reader) @@ -735,14 +784,16 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i } // Sync implements vfs.FileDescriptionImpl.Sync. -func (f *fileDescription) Sync(context.Context) error { +func (f *fileDescription) Sync(ctx context.Context) error { // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { - if !f.inode.canMap { + // NOTE(b/38213152): Technically, some obscure char devices can be memory + // mapped, but we only allow regular files. + if f.inode.ftype != syscall.S_IFREG { return syserror.ENODEV } i := f.inode @@ -753,13 +804,17 @@ func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts // EventRegister implements waiter.Waitable.EventRegister. func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { f.inode.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // EventUnregister implements waiter.Waitable.EventUnregister. func (f *fileDescription) EventUnregister(e *waiter.Entry) { f.inode.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // Readiness uses the poll() syscall to check the status of the underlying FD. diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go index b51a17bed..3d7eb2f96 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -43,7 +43,7 @@ type inodePlatformFile struct { fileMapper fsutil.HostFileMapper // fileMapperInitOnce is used to lazily initialize fileMapper. - fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + fileMapperInitOnce sync.Once `state:"nosave"` } // IncRef implements memmap.File.IncRef. diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go new file mode 100644 index 000000000..7e32a8863 --- /dev/null +++ b/pkg/sentry/fsimpl/host/save_restore.go @@ -0,0 +1,78 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "fmt" + "io" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/usermem" +) + +// beforeSave is invoked by stateify. +func (i *inode) beforeSave() { + if !i.savable { + panic("host.inode is not savable") + } + if i.ftype == syscall.S_IFIFO { + // If this pipe FD is readable, drain it so that bytes in the pipe can + // be read after restore. (This is a legacy VFS1 feature.) We don't + // know if the pipe FD is readable, so just try reading and tolerate + // EBADF from the read. + i.bufMu.Lock() + defer i.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */) + if n != 0 { + i.buf = append(i.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syscall.EAGAIN || err == syscall.EBADF { + break + } + panic(fmt.Errorf("host.inode.beforeSave: buffering from pipe failed: %v", err)) + } + } + if len(i.buf) != 0 { + atomic.StoreUint32(&i.haveBuf, 1) + } + } +} + +// afterLoad is invoked by stateify. +func (i *inode) afterLoad() { + if i.mayBlock { + if err := syscall.SetNonblock(i.hostFD, true); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: failed to set host FD %d non-blocking: %v", i.hostFD, err)) + } + if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: fdnotifier.AddFD(%d) failed: %v", i.hostFD, err)) + } + } +} + +// afterLoad is invoked by stateify. +func (i *inodePlatformFile) afterLoad() { + if i.fileMapper.IsInited() { + // Ensure that we don't call i.fileMapper.Init() again. + i.fileMapperInitOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 412bdb2eb..b2f43a119 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -43,12 +43,6 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)} } -// wouldBlock returns true for file types that can return EWOULDBLOCK -// for blocking operations, e.g. pipes, character devices, and sockets. -func wouldBlock(fileType uint32) bool { - return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK -} - // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. // If so, they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 858cc24ce..aaad67ab8 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -4,6 +4,18 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) go_template_instance( + name = "dentry_list", + out = "dentry_list.go", + package = "kernfs", + prefix = "dentry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Dentry", + "Linker": "*Dentry", + }, +) + +go_template_instance( name = "fstree", out = "fstree.go", package = "kernfs", @@ -27,22 +39,11 @@ go_template_instance( ) go_template_instance( - name = "dentry_refs", - out = "dentry_refs.go", - package = "kernfs", - prefix = "Dentry", - template = "//pkg/refs_vfs2:refs_template", - types = { - "T": "Dentry", - }, -) - -go_template_instance( name = "static_directory_refs", out = "static_directory_refs.go", package = "kernfs", prefix = "StaticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "StaticDirectory", }, @@ -53,7 +54,7 @@ go_template_instance( out = "dir_refs.go", package = "kernfs_test", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -64,7 +65,7 @@ go_template_instance( out = "readonly_dir_refs.go", package = "kernfs_test", prefix = "readonlyDir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "readonlyDir", }, @@ -75,7 +76,7 @@ go_template_instance( out = "synthetic_directory_refs.go", package = "kernfs", prefix = "syntheticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "syntheticDirectory", }, @@ -84,7 +85,7 @@ go_template_instance( go_library( name = "kernfs", srcs = [ - "dentry_refs.go", + "dentry_list.go", "dynamic_bytes_file.go", "fd_impl_util.go", "filesystem.go", @@ -104,8 +105,11 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", + "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", @@ -129,6 +133,7 @@ go_test( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index b929118b1..485504995 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -47,11 +47,11 @@ type DynamicBytesFile struct { var _ Inode = (*DynamicBytesFile)(nil) // Init initializes a dynamic bytes file. -func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { +func (f *DynamicBytesFile) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) f.data = data } diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index abf1905d6..f8dae22f8 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -145,8 +145,12 @@ func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { return fd.vfsfd.VirtualDentry().Mount().Filesystem() } +func (fd *GenericDirectoryFD) dentry() *Dentry { + return fd.vfsfd.Dentry().Impl().(*Dentry) +} + func (fd *GenericDirectoryFD) inode() Inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode + return fd.dentry().inode } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds @@ -176,8 +180,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent // Handle "..". if fd.off == 1 { - vfsd := fd.vfsfd.VirtualDentry().Dentry() - parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode + parentInode := genericParentOrSelf(fd.dentry()).inode stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err @@ -219,7 +222,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent var err error relOffset := fd.off - int64(len(fd.children.set)) - 2 - fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset) + fd.off, err = fd.inode().IterDirents(ctx, fd.vfsfd.Mount(), cb, fd.off, relOffset) return err } @@ -265,8 +268,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) - inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode - return inode.SetStat(ctx, fd.filesystem(), creds, opts) + return fd.inode().SetStat(ctx, fd.filesystem(), creds, opts) } // Allocate implements vfs.FileDescriptionImpl.Allocate. diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 6426a55f6..399895f3e 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -373,7 +373,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err } - childI = newSyntheticDirectory(rp.Credentials(), opts.Mode) + childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode) } var child Dentry child.Init(fs, childI) @@ -517,9 +517,6 @@ afterTrailingSymlink: } var child Dentry child.Init(fs, childI) - // FIXME(gvisor.dev/issue/1193): Race between checking existence with - // fs.stepExistingLocked and parent.insertChild. If possible, we should hold - // dirMu from one to the other. parent.insertChild(pc, &child) // Open may block so we need to unlock fs.mu. IncRef child to prevent // its destruction while fs.mu is unlocked. diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 122b10591..d9d76758a 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -21,9 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // InodeNoopRefCount partially implements the Inode interface, specifically the @@ -143,7 +145,7 @@ func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error) } // IterDirents implements Inode.IterDirents. -func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (InodeNotDirectory) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { panic("IterDirents called on non-directory inode") } @@ -172,17 +174,23 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, // // +stateify savable type InodeAttrs struct { - devMajor uint32 - devMinor uint32 - ino uint64 - mode uint32 - uid uint32 - gid uint32 - nlink uint32 + devMajor uint32 + devMinor uint32 + ino uint64 + mode uint32 + uid uint32 + gid uint32 + nlink uint32 + blockSize uint32 + + // Timestamps, all nsecs from the Unix epoch. + atime int64 + mtime int64 + ctime int64 } // Init initializes this InodeAttrs. -func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { +func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { if mode.FileType() == 0 { panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) } @@ -198,6 +206,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, in atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) atomic.StoreUint32(&a.nlink, nlink) + atomic.StoreUint32(&a.blockSize, usermem.PageSize) + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.atime, now) + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) } // DevMajor returns the device major number. @@ -220,12 +233,33 @@ func (a *InodeAttrs) Mode() linux.FileMode { return linux.FileMode(atomic.LoadUint32(&a.mode)) } +// TouchAtime updates a.atime to the current time. +func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) { + if mnt.Flags.NoATime || mnt.ReadOnly() { + return + } + if err := mnt.CheckBeginWrite(); err != nil { + return + } + atomic.StoreInt64(&a.atime, ktime.NowFromContext(ctx).Nanoseconds()) + mnt.EndWrite() +} + +// TouchCMtime updates a.{c/m}time to the current time. The caller should +// synchronize calls to this so that ctime and mtime are updated to the same +// value. +func (a *InodeAttrs) TouchCMtime(ctx context.Context) { + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) +} + // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK + stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME stat.DevMajor = a.devMajor stat.DevMinor = a.devMinor stat.Ino = atomic.LoadUint64(&a.ino) @@ -233,21 +267,15 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li stat.UID = atomic.LoadUint32(&a.uid) stat.GID = atomic.LoadUint32(&a.gid) stat.Nlink = atomic.LoadUint32(&a.nlink) - - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - + stat.Blksize = atomic.LoadUint32(&a.blockSize) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.atime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.mtime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.ctime)) return stat, nil } // SetStat implements Inode.SetStat. func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - return a.SetInodeStat(ctx, fs, creds, opts) -} - -// SetInodeStat sets the corresponding attributes from opts to InodeAttrs. -// This function can be used by other kernfs-based filesystem implementation to -// sets the unexported attributes into InodeAttrs. -func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask == 0 { return nil } @@ -256,9 +284,7 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds // inode numbers are immutable after node creation. Setting the size is often // allowed by kernfs files but does not do anything. If some other behavior is // needed, the embedder should consider extending SetStat. - // - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 { return syserror.EPERM } if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { @@ -286,6 +312,20 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds atomic.StoreUint32(&a.gid, stat.GID) } + now := ktime.NowFromContext(ctx).Nanoseconds() + if stat.Mask&linux.STATX_ATIME != 0 { + if stat.Atime.Nsec == linux.UTIME_NOW { + stat.Atime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.atime, stat.Atime.ToNsec()) + } + if stat.Mask&linux.STATX_MTIME != 0 { + if stat.Mtime.Nsec == linux.UTIME_NOW { + stat.Mtime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.mtime, stat.Mtime.ToNsec()) + } + return nil } @@ -421,7 +461,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error } // IterDirents implements Inode.IterDirents. -func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (o *OrderedChildren) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { // All entries from OrderedChildren have already been handled in // GenericDirectoryFD.IterDirents. return offset, nil @@ -619,9 +659,9 @@ type StaticDirectory struct { var _ Inode = (*StaticDirectory)(nil) // NewStaticDir creates a new static directory and returns its dentry. -func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { +func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { inode := &StaticDirectory{} - inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts) + inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts) inode.EnableLeakCheck() inode.OrderedChildren.Init(OrderedChildrenOptions{}) @@ -632,12 +672,12 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64 } // Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { +func (s *StaticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } s.fdOpts = fdOpts - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } // Open implements Inode.Open. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 606081e68..5c5e09ac5 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -107,6 +107,17 @@ type Filesystem struct { // nextInoMinusOne is used to to allocate inode numbers on this // filesystem. Must be accessed by atomic operations. nextInoMinusOne uint64 + + // cachedDentries contains all dentries with 0 references. (Due to race + // conditions, it may also contain dentries with non-zero references.) + // cachedDentriesLen is the number of dentries in cachedDentries. These + // fields are protected by mu. + cachedDentries dentryList + cachedDentriesLen uint64 + + // MaxCachedDentries is the maximum size of cachedDentries. If not set, + // defaults to 0 and kernfs does not cache any dentries. This is immutable. + MaxCachedDentries uint64 } // deferDecRef defers dropping a dentry ref until the next call to @@ -165,7 +176,12 @@ const ( // +stateify savable type Dentry struct { vfsd vfs.Dentry - DentryRefs + + // refs is the reference count. When refs reaches 0, the dentry may be + // added to the cache or destroyed. If refs == -1, the dentry has already + // been destroyed. refs are allowed to go to 0 and increase again. refs is + // accessed using atomic memory operations. + refs int64 // fs is the owning filesystem. fs is immutable. fs *Filesystem @@ -177,6 +193,12 @@ type Dentry struct { parent *Dentry name string + // If cached is true, dentryEntry links dentry into + // Filesystem.cachedDentries. cached and dentryEntry are protected by + // Filesystem.mu. + cached bool + dentryEntry + // dirMu protects children and the names of child Dentries. // // Note that holding fs.mu for writing is not sufficient; @@ -188,6 +210,150 @@ type Dentry struct { inode Inode } +// IncRef implements vfs.DentryImpl.IncRef. +func (d *Dentry) IncRef() { + // d.refs may be 0 if d.fs.mu is locked, which serializes against + // d.cacheLocked(). + atomic.AddInt64(&d.refs, 1) +} + +// TryIncRef implements vfs.DentryImpl.TryIncRef. +func (d *Dentry) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&d.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + return true + } + } +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *Dentry) DecRef(ctx context.Context) { + if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + d.fs.mu.Lock() + d.cacheLocked(ctx) + d.fs.mu.Unlock() + } else if refs < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } +} + +// cacheLocked should be called after d's reference count becomes 0. The ref +// count check may happen before acquiring d.fs.mu so there might be a race +// condition where the ref count is increased again by the time the caller +// acquires d.fs.mu. This race is handled. +// Only reachable dentries are added to the cache. However, a dentry might +// become unreachable *while* it is in the cache due to invalidation. +// +// Preconditions: d.fs.mu must be locked for writing. +func (d *Dentry) cacheLocked(ctx context.Context) { + // Dentries with a non-zero reference count must be retained. (The only way + // to obtain a reference on a dentry with zero references is via path + // resolution, which requires d.fs.mu, so if d.refs is zero then it will + // remain zero while we hold d.fs.mu for writing.) + refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d)) + } + if refs > 0 { + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + return + } + // If the dentry is deleted and invalidated or has no parent, then it is no + // longer reachable by path resolution and should be dropped immediately + // because it has zero references. + // Note that a dentry may not always have a parent; for example magic links + // as described in Inode.Getlink. + if isDead := d.VFSDentry().IsDead(); isDead || d.parent == nil { + if !isDead { + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) + } + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + d.destroyLocked(ctx) + return + } + // If d is already cached, just move it to the front of the LRU. + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentries.PushFront(d) + return + } + // Cache the dentry, then evict the least recently used cached dentry if + // the cache becomes over-full. + d.fs.cachedDentries.PushFront(d) + d.fs.cachedDentriesLen++ + d.cached = true + if d.fs.cachedDentriesLen <= d.fs.MaxCachedDentries { + return + } + // Evict the least recently used dentry because cache size is greater than + // max cache size (configured on mount). + victim := d.fs.cachedDentries.Back() + d.fs.cachedDentries.Remove(victim) + d.fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // after it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if !victim.vfsd.IsDead() { + victim.parent.dirMu.Lock() + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, victim.VFSDentry()) + delete(victim.parent.children, victim.name) + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.MaxCachedDentries, so we don't loop. +} + +// destroyLocked destroys the dentry. +// +// Preconditions: +// * d.fs.mu must be locked for writing. +// * d.refs == 0. +// * d should have been removed from d.parent.children, i.e. d is not reachable +// by path traversal. +// * d.vfsd.IsDead() is true. +func (d *Dentry) destroyLocked(ctx context.Context) { + switch atomic.LoadInt64(&d.refs) { + case 0: + // Mark the dentry destroyed. + atomic.StoreInt64(&d.refs, -1) + case -1: + panic("dentry.destroyLocked() called on already destroyed dentry") + default: + panic("dentry.destroyLocked() called with references on the dentry") + } + + d.inode.DecRef(ctx) // IncRef from Init. + d.inode = nil + + // Drop the reference held by d on its parent without recursively locking + // d.fs.mu. + if d.parent != nil { + if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { + d.parent.cacheLocked(ctx) + } else if refs < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } + } +} + // Init initializes this dentry. // // Precondition: Caller must hold a reference on inode. @@ -197,6 +363,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { d.vfsd.Init(d) d.fs = fs d.inode = inode + atomic.StoreInt64(&d.refs, 1) ftype := inode.Mode().FileType() if ftype == linux.ModeDirectory { d.flags |= dflagsIsDir @@ -204,7 +371,6 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { if ftype == linux.ModeSymlink { d.flags |= dflagsIsSymlink } - d.EnableLeakCheck() } // VFSDentry returns the generic vfs dentry for this kernfs dentry. @@ -222,32 +388,6 @@ func (d *Dentry) isSymlink() bool { return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0 } -// DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef(ctx context.Context) { - decRefParent := false - d.fs.mu.Lock() - d.DentryRefs.DecRef(func() { - d.inode.DecRef(ctx) // IncRef from Init. - d.inode = nil - if d.parent != nil { - // We will DecRef d.parent once all locks are dropped. - decRefParent = true - d.parent.dirMu.Lock() - // Remove d from parent.children. It might already have been - // removed due to invalidation. - if _, ok := d.parent.children[d.name]; ok { - delete(d.parent.children, d.name) - d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) - } - d.parent.dirMu.Unlock() - } - }) - d.fs.mu.Unlock() - if decRefParent { - d.parent.DecRef(ctx) // IncRef from Dentry.insertChild. - } -} - // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // // Although Linux technically supports inotify on pseudo filesystems (inotify @@ -267,7 +407,9 @@ func (d *Dentry) OnZeroWatches(context.Context) {} // this dentry. This does not update the directory inode, so calling this on its // own isn't sufficient to insert a child into a directory. // -// Precondition: d must represent a directory inode. +// Preconditions: +// * d must represent a directory inode. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChild(name string, child *Dentry) { d.dirMu.Lock() d.insertChildLocked(name, child) @@ -280,6 +422,7 @@ func (d *Dentry) insertChild(name string, child *Dentry) { // Preconditions: // * d must represent a directory inode. // * d.dirMu must be locked. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChildLocked(name string, child *Dentry) { if !d.isDir() { panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d)) @@ -436,7 +579,7 @@ type inodeDirectory interface { // the inode is a directory. // // The child returned by Lookup will be hashed into the VFS dentry tree, - // atleast for the duration of the current FS operation. + // at least for the duration of the current FS operation. // // Lookup must return the child with an extra reference whose ownership is // transferred to the dentry that is created to point to that inode. If @@ -454,7 +597,7 @@ type inodeDirectory interface { // inside the entries returned by this IterDirents invocation. In other words, // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as // the return value, while 'relOffset' is the place to start iteration. - IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) + IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) } type inodeSymlink interface { diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 82fa19c03..2418eec44 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file." // RootDentryFn is a generator function for creating the root dentry of a test // filesystem. See newTestSystem. -type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode +type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode // newTestSystem sets up a minimal environment for running a test, including an // instance of a test filesystem. Tests can control the contents of the @@ -72,10 +72,10 @@ type file struct { content string } -func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode { +func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode { f := &file{} f.content = content - f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) + f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) return f } @@ -105,9 +105,9 @@ type readonlyDir struct { locks vfs.FileLocks } -func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &readonlyDir{} - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) dir.EnableLeakCheck() dir.IncLinks(dir.OrderedChildren.Populate(contents)) @@ -142,10 +142,10 @@ type dir struct { fs *filesystem } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &dir{} dir.fs = fs - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) dir.EnableLeakCheck() @@ -169,22 +169,24 @@ func (d *dir) DecRef(ctx context.Context) { func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - dir := d.fs.newDir(creds, opts.Mode, nil) + dir := d.fs.newDir(ctx, creds, opts.Mode, nil) if err := d.OrderedChildren.Insert(name, dir); err != nil { dir.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) d.IncLinks(1) return dir, nil } func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - f := d.fs.newFile(creds, "") + f := d.fs.newFile(ctx, creds, "") if err := d.OrderedChildren.Insert(name, f); err != nil { f.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) return f, nil } @@ -209,7 +211,7 @@ func (fsType) Release(ctx context.Context) {} func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { fs := &filesystem{} fs.VFSFilesystem().Init(vfsObj, &fst, fs) - root := fst.rootFn(creds, fs) + root := fst.rootFn(ctx, creds, fs) var d kernfs.Dentry d.Init(&fs.Filesystem, root) return fs.VFSFilesystem(), d.VFSDentry(), nil @@ -218,9 +220,9 @@ func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesyst // -------------------- Remainder of the file are test cases -------------------- func TestBasic(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -228,9 +230,9 @@ func TestBasic(t *testing.T) { } func TestMkdirGetDentry(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -243,9 +245,9 @@ func TestMkdirGetDentry(t *testing.T) { } func TestReadStaticFile(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -269,9 +271,9 @@ func TestReadStaticFile(t *testing.T) { } func TestCreateNewFileInStaticDir(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -296,8 +298,8 @@ func TestCreateNewFileInStaticDir(t *testing.T) { } func TestDirFDReadWrite(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, nil) + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, nil) }) defer sys.Destroy() @@ -320,14 +322,14 @@ func TestDirFDReadWrite(t *testing.T) { } func TestDirFDIterDirents(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ // Fill root with nodes backed by various inode implementations. - "dir1": fs.newReadonlyDir(creds, 0755, nil), - "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{ - "dir3": fs.newDir(creds, 0755, nil), + "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil), + "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir3": fs.newDir(ctx, creds, 0755, nil), }), - "file1": fs.newFile(creds, staticFileContent), + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 934cc6c9e..a0736c0d6 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -38,16 +38,16 @@ type StaticSymlink struct { var _ Inode = (*StaticSymlink)(nil) // NewStaticSymlink creates a new symlink file pointing to 'target'. -func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { +func NewStaticSymlink(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { inode := &StaticSymlink{} - inode.Init(creds, devMajor, devMinor, ino, target) + inode.Init(ctx, creds, devMajor, devMinor, ino, target) return inode } // Init initializes the instance. -func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { +func (s *StaticSymlink) Init(ctx context.Context, creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { s.target = target - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } // Readlink implements Inode.Readlink. diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go index d0ed17b18..463d77d79 100644 --- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -41,17 +41,17 @@ type syntheticDirectory struct { var _ Inode = (*syntheticDirectory)(nil) -func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode { +func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode { inode := &syntheticDirectory{} - inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) + inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) return inode } -func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) } - dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) dir.OrderedChildren.Init(OrderedChildrenOptions{ Writable: true, }) @@ -76,11 +76,12 @@ func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs if !opts.ForSyntheticMountpoint { return nil, syserror.EPERM } - subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) + subdirI := newSyntheticDirectory(ctx, auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) if err := dir.OrderedChildren.Insert(name, subdirI); err != nil { subdirI.DecRef(ctx) return nil, err } + dir.TouchCMtime(ctx) return subdirI, nil } diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index 8cf5b35d3..fd6c55921 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -21,14 +21,18 @@ go_library( "directory.go", "filesystem.go", "fstree.go", - "non_directory.go", "overlay.go", + "regular_file.go", + "save_restore.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/log", + "//pkg/refsvfs2", + "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", @@ -37,5 +41,6 @@ go_library( "//pkg/sync", "//pkg/syserror", "//pkg/usermem", + "//pkg/waiter", ], ) diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 73b126669..4506642ca 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -75,8 +75,21 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { return syserror.ENOENT } - // Perform copy-up. + // Obtain settable timestamps from the lower layer. vfsObj := d.fs.vfsfs.VirtualFilesystem() + oldpop := vfs.PathOperation{ + Root: d.lowerVDs[0], + Start: d.lowerVDs[0], + } + const timestampsMask = linux.STATX_ATIME | linux.STATX_MTIME + oldStat, err := vfsObj.StatAt(ctx, d.fs.creds, &oldpop, &vfs.StatOptions{ + Mask: timestampsMask, + }) + if err != nil { + return err + } + + // Perform copy-up. newpop := vfs.PathOperation{ Root: d.parent.upperVD, Start: d.parent.upperVD, @@ -101,10 +114,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } switch ftype { case linux.S_IFREG: - oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ - Root: d.lowerVDs[0], - Start: d.lowerVDs[0], - }, &vfs.OpenOptions{ + oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &oldpop, &vfs.OpenOptions{ Flags: linux.O_RDONLY, }) if err != nil { @@ -160,9 +170,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } if err := newFD.SetStat(ctx, vfs.SetStatOptions{ Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: d.uid, - GID: d.gid, + Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask×tampsMask, + UID: d.uid, + GID: d.gid, + Atime: oldStat.Atime, + Mtime: oldStat.Mtime, }, }); err != nil { cleanupUndoCopyUp() @@ -179,9 +191,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{ Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: d.uid, - GID: d.gid, + Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask×tampsMask, + UID: d.uid, + GID: d.gid, + Atime: oldStat.Atime, + Mtime: oldStat.Mtime, }, }); err != nil { cleanupUndoCopyUp() @@ -195,10 +209,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { d.upperVD = upperVD case linux.S_IFLNK: - target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ - Root: d.lowerVDs[0], - Start: d.lowerVDs[0], - }) + target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &oldpop) if err != nil { return err } @@ -207,10 +218,12 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{ Stat: linux.Statx{ - Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID, - Mode: uint16(d.mode), - UID: d.uid, - GID: d.gid, + Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | oldStat.Mask×tampsMask, + Mode: uint16(d.mode), + UID: d.uid, + GID: d.gid, + Atime: oldStat.Atime, + Mtime: oldStat.Mtime, }, }); err != nil { cleanupUndoCopyUp() @@ -224,25 +237,20 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { d.upperVD = upperVD case linux.S_IFBLK, linux.S_IFCHR: - lowerStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{ - Root: d.lowerVDs[0], - Start: d.lowerVDs[0], - }, &vfs.StatOptions{}) - if err != nil { - return err - } if err := vfsObj.MknodAt(ctx, d.fs.creds, &newpop, &vfs.MknodOptions{ Mode: linux.FileMode(d.mode), - DevMajor: lowerStat.RdevMajor, - DevMinor: lowerStat.RdevMinor, + DevMajor: oldStat.RdevMajor, + DevMinor: oldStat.RdevMinor, }); err != nil { return err } if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{ Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: d.uid, - GID: d.gid, + Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask×tampsMask, + UID: d.uid, + GID: d.gid, + Atime: oldStat.Atime, + Mtime: oldStat.Mtime, }, }); err != nil { cleanupUndoCopyUp() diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index bd11372d5..10161a08d 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -302,8 +302,14 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str child.devMinor = fs.dirDevMinor child.ino = fs.newDirIno() } else if !child.upperVD.Ok() { + childDevMinor, err := fs.getLowerDevMinor(child.devMajor, child.devMinor) + if err != nil { + ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err) + child.destroyLocked(ctx) + return nil, err + } child.devMajor = linux.UNNAMED_MAJOR - child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()] + child.devMinor = childDevMinor } parent.IncRef() @@ -765,7 +771,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - if mayWrite { + if start.isRegularFile() && mayWrite { if err := start.copyUpLocked(ctx); err != nil { return nil, err } @@ -819,7 +825,7 @@ afterTrailingSymlink: if rp.MustBeDir() && !child.isDir() { return nil, syserror.ENOTDIR } - if mayWrite { + if child.isRegularFile() && mayWrite { if err := child.copyUpLocked(ctx); err != nil { return nil, err } @@ -872,8 +878,11 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts * if err != nil { return nil, err } + if ftype != linux.S_IFREG { + return layerFD, nil + } layerFlags := layerFD.StatusFlags() - fd := &nonDirectoryFD{ + fd := ®ularFileFD{ copiedUp: isUpper, cachedFD: layerFD, cachedFlags: layerFlags, @@ -969,7 +978,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } // Finally construct the overlay FD. upperFlags := upperFD.StatusFlags() - fd := &nonDirectoryFD{ + fd := ®ularFileFD{ copiedUp: true, cachedFD: upperFD, cachedFlags: upperFlags, @@ -1293,6 +1302,9 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error if !child.isDir() { return syserror.ENOTDIR } + if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(atomic.LoadUint32(&parent.mode)), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil { + return err + } child.dirMu.Lock() defer child.dirMu.Unlock() whiteouts, err := child.collectWhiteoutsForRmdirLocked(ctx) @@ -1528,12 +1540,38 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } + parentMode := atomic.LoadUint32(&parent.mode) child := parent.children[name] var childLayer lookupLayer + if child == nil { + if parentMode&linux.S_ISVTX != 0 { + // If the parent's sticky bit is set, we need a child dentry to get + // its owner. + child, err = fs.getChildLocked(ctx, parent, name, &ds) + if err != nil { + return err + } + } else { + // Determine if the file being unlinked actually exists. Holding + // parent.dirMu prevents a dentry from being instantiated for the file, + // which in turn prevents it from being copied-up, so this result is + // stable. + childLayer, err = fs.lookupLayerLocked(ctx, parent, name) + if err != nil { + return err + } + if !childLayer.existsInOverlay() { + return syserror.ENOENT + } + } + } if child != nil { if child.isDir() { return syserror.EISDIR } + if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(parentMode), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil { + return err + } if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { return err } @@ -1546,18 +1584,6 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } else { childLayer = lookupLayerLower } - } else { - // Determine if the file being unlinked actually exists. Holding - // parent.dirMu prevents a dentry from being instantiated for the file, - // which in turn prevents it from being copied-up, so this result is - // stable. - childLayer, err = fs.lookupLayerLocked(ctx, parent, name) - if err != nil { - return err - } - if !childLayer.existsInOverlay() { - return syserror.ENOENT - } } pop := vfs.PathOperation{ diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index e5f506d2e..c812f0a70 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -18,10 +18,11 @@ // // Lock order: // -// directoryFD.mu / nonDirectoryFD.mu +// directoryFD.mu / regularFileFD.mu // filesystem.renameMu // dentry.dirMu // dentry.copyMu +// filesystem.devMu // *** "memmap.Mappable locks" below this point // dentry.mapsMu // *** "memmap.Mappable locks taken by Translate" below this point @@ -33,12 +34,14 @@ package overlay import ( + "fmt" "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -99,10 +102,15 @@ type filesystem struct { // is immutable. dirDevMinor uint32 - // lowerDevMinors maps lower layer filesystems to device minor numbers - // assigned to non-directory files originating from that filesystem. - // lowerDevMinors is immutable. - lowerDevMinors map[*vfs.Filesystem]uint32 + // lowerDevMinors maps device numbers from lower layer filesystems to + // device minor numbers assigned to non-directory files originating from + // that filesystem. (This remapping is necessary for lower layers because a + // file on a lower layer, and that same file on an overlay, are + // distinguishable because they will diverge after copy-up; this isn't true + // for non-directory files already on the upper layer.) lowerDevMinors is + // protected by devMu. + devMu sync.Mutex `state:"nosave"` + lowerDevMinors map[layerDevNumber]uint32 // renameMu synchronizes renaming with non-renaming operations in order to // ensure consistent lock ordering between dentry.dirMu in different @@ -114,78 +122,69 @@ type filesystem struct { lastDirIno uint64 } +// +stateify savable +type layerDevNumber struct { + major uint32 + minor uint32 +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { mopts := vfs.GenericParseMountOptions(opts.Data) fsoptsRaw := opts.InternalData - fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions) - if fsoptsRaw != nil && !haveFSOpts { + fsopts, ok := fsoptsRaw.(FilesystemOptions) + if fsoptsRaw != nil && !ok { ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) return nil, nil, syserror.EINVAL } - if haveFSOpts { - if len(fsopts.LowerRoots) == 0 { - ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + + if upperPathname, ok := mopts["upperdir"]; ok { + if fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified") return nil, nil, syserror.EINVAL } - if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") + delete(mopts, "upperdir") + // Linux overlayfs also requires a workdir when upperdir is + // specified; we don't, so silently ignore this option. + delete(mopts, "workdir") + upperPath := fspath.Parse(upperPathname) + if !upperPath.Absolute { + ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) return nil, nil, syserror.EINVAL } - // We don't enforce a maximum number of lower layers when not - // configured by applications; the sandbox owner can have an overlay - // filesystem with any number of lower layers. - } else { - vfsroot := vfs.RootFromContext(ctx) - defer vfsroot.DecRef(ctx) - upperPathname, ok := mopts["upperdir"] - if ok { - delete(mopts, "upperdir") - // Linux overlayfs also requires a workdir when upperdir is - // specified; we don't, so silently ignore this option. - delete(mopts, "workdir") - upperPath := fspath.Parse(upperPathname) - if !upperPath.Absolute { - ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) - return nil, nil, syserror.EINVAL - } - upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ - Root: vfsroot, - Start: vfsroot, - Path: upperPath, - FollowFinalSymlink: true, - }, &vfs.GetDentryOptions{ - CheckSearchable: true, - }) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer upperRoot.DecRef(ctx) - privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer privateUpperRoot.DecRef(ctx) - fsopts.UpperRoot = privateUpperRoot + upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: upperPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) + return nil, nil, err + } + privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) + upperRoot.DecRef(ctx) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) + return nil, nil, err } - lowerPathnamesStr, ok := mopts["lowerdir"] - if !ok { - ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") + defer privateUpperRoot.DecRef(ctx) + fsopts.UpperRoot = privateUpperRoot + } + + if lowerPathnamesStr, ok := mopts["lowerdir"]; ok { + if len(fsopts.LowerRoots) != 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified") return nil, nil, syserror.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") - const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK - if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") - return nil, nil, syserror.EINVAL - } - if len(lowerPathnames) > maxLowerLayers { - ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) - return nil, nil, syserror.EINVAL - } for _, lowerPathname := range lowerPathnames { lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { @@ -204,8 +203,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) + lowerRoot.DecRef(ctx) if err != nil { ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err @@ -214,31 +213,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot) } } + if len(mopts) != 0 { ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) return nil, nil, syserror.EINVAL } - // Allocate device numbers. + if len(fsopts.LowerRoots) == 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required") + return nil, nil, syserror.EINVAL + } + if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present") + return nil, nil, syserror.EINVAL + } + const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK + if len(fsopts.LowerRoots) > maxLowerLayers { + ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers) + return nil, nil, syserror.EINVAL + } + + // Allocate dirDevMinor. lowerDevMinors are allocated dynamically. dirDevMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } - lowerDevMinors := make(map[*vfs.Filesystem]uint32) - for _, lowerRoot := range fsopts.LowerRoots { - lowerFS := lowerRoot.Mount().Filesystem() - if _, ok := lowerDevMinors[lowerFS]; !ok { - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - vfsObj.PutAnonBlockDevMinor(dirDevMinor) - for _, lowerDevMinor := range lowerDevMinors { - vfsObj.PutAnonBlockDevMinor(lowerDevMinor) - } - return nil, nil, err - } - lowerDevMinors[lowerFS] = devMinor - } - } // Take extra references held by the filesystem. if fsopts.UpperRoot.Ok() { @@ -252,7 +251,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt opts: fsopts, creds: creds.Fork(), dirDevMinor: dirDevMinor, - lowerDevMinors: lowerDevMinors, + lowerDevMinors: make(map[layerDevNumber]uint32), } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -302,7 +301,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt root.ino = fs.newDirIno() } else if !root.upperVD.Ok() { root.devMajor = linux.UNNAMED_MAJOR - root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()] + rootDevMinor, err := fs.getLowerDevMinor(rootStat.DevMajor, rootStat.DevMinor) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to get device number for root: %v", err) + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) + return nil, nil, err + } + root.devMinor = rootDevMinor root.ino = rootStat.Ino } else { root.devMajor = rootStat.DevMajor @@ -375,6 +381,21 @@ func (fs *filesystem) newDirIno() uint64 { return atomic.AddUint64(&fs.lastDirIno, 1) } +func (fs *filesystem) getLowerDevMinor(layerMajor, layerMinor uint32) (uint32, error) { + fs.devMu.Lock() + defer fs.devMu.Unlock() + orig := layerDevNumber{layerMajor, layerMinor} + if minor, ok := fs.lowerDevMinors[orig]; ok { + return minor, nil + } + minor, err := fs.vfsfs.VirtualFilesystem().GetAnonBlockDevMinor() + if err != nil { + return 0, err + } + fs.lowerDevMinors[orig] = minor + return minor, nil +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -453,14 +474,14 @@ type dentry struct { // - If this dentry is copied-up, then wrappedMappable is the Mappable // obtained from a call to the current top layer's // FileDescription.ConfigureMMap(). Once wrappedMappable becomes non-nil - // (from a call to nonDirectoryFD.ensureMappable()), it cannot become nil. + // (from a call to regularFileFD.ensureMappable()), it cannot become nil. // wrappedMappable is protected by mapsMu and dataMu. // // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is // accessed using atomic memory operations. - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` lowerMappings memmap.MappingSet - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` wrappedMappable memmap.Mappable isMappable uint32 @@ -484,6 +505,9 @@ func (fs *filesystem) newDentry() *dentry { } d.lowerVDs = d.inlineLowerVDs[:0] d.vfsd.Init(d) + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(d, "overlay.dentry") + } return d } @@ -583,6 +607,14 @@ func (d *dentry) destroyLocked(ctx context.Context) { panic("overlay.dentry.DecRef() called without holding a reference") } } + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(d, "overlay.dentry") + } +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[overlay.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 853aee951..2b89a7a6d 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -19,14 +19,21 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" ) +func (d *dentry) isRegularFile() bool { + return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFREG +} + func (d *dentry) isSymlink() bool { return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK } @@ -40,7 +47,7 @@ func (d *dentry) readlink(ctx context.Context) (string, error) { } // +stateify savable -type nonDirectoryFD struct { +type regularFileFD struct { fileDescription // If copiedUp is false, cachedFD represents @@ -52,9 +59,13 @@ type nonDirectoryFD struct { copiedUp bool cachedFD *vfs.FileDescription cachedFlags uint32 + + // If copiedUp is false, lowerWaiters contains all waiter.Entries + // registered with cachedFD. lowerWaiters is protected by mu. + lowerWaiters map[*waiter.Entry]waiter.EventMask } -func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) { +func (fd *regularFileFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) { fd.mu.Lock() defer fd.mu.Unlock() wrappedFD, err := fd.currentFDLocked(ctx) @@ -65,7 +76,7 @@ func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescriptio return wrappedFD, nil } -func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) { +func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) { d := fd.dentry() statusFlags := fd.vfsfd.StatusFlags() if !fd.copiedUp && d.isCopiedUp() { @@ -87,10 +98,21 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip return nil, err } } + if len(fd.lowerWaiters) != 0 { + ready := upperFD.Readiness(^waiter.EventMask(0)) + for e, mask := range fd.lowerWaiters { + fd.cachedFD.EventUnregister(e) + upperFD.EventRegister(e, mask) + if ready&mask != 0 { + e.Callback.Callback(e) + } + } + } fd.cachedFD.DecRef(ctx) fd.copiedUp = true fd.cachedFD = upperFD fd.cachedFlags = statusFlags + fd.lowerWaiters = nil } else if fd.cachedFlags != statusFlags { if err := fd.cachedFD.SetStatusFlags(ctx, d.fs.creds, statusFlags); err != nil { return nil, err @@ -101,13 +123,13 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *nonDirectoryFD) Release(ctx context.Context) { +func (fd *regularFileFD) Release(ctx context.Context) { fd.cachedFD.DecRef(ctx) fd.cachedFD = nil } // OnClose implements vfs.FileDescriptionImpl.OnClose. -func (fd *nonDirectoryFD) OnClose(ctx context.Context) error { +func (fd *regularFileFD) OnClose(ctx context.Context) error { // Linux doesn't define ovl_file_operations.flush at all (i.e. its // equivalent to OnClose is a no-op). We pass through to // fd.cachedFD.OnClose() without upgrading if fd.dentry() has been @@ -128,7 +150,7 @@ func (fd *nonDirectoryFD) OnClose(ctx context.Context) error { } // Stat implements vfs.FileDescriptionImpl.Stat. -func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { +func (fd *regularFileFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 { wrappedFD, err := fd.getCurrentFD(ctx) @@ -149,7 +171,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux } // Allocate implements vfs.FileDescriptionImpl.Allocate. -func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { +func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { wrappedFD, err := fd.getCurrentFD(ctx) if err != nil { return err @@ -159,7 +181,7 @@ func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uin } // SetStat implements vfs.FileDescriptionImpl.SetStat. -func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { +func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() mode := linux.FileMode(atomic.LoadUint32(&d.mode)) if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { @@ -191,12 +213,61 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) } // StatFS implements vfs.FileDescriptionImpl.StatFS. -func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) { +func (fd *regularFileFD) StatFS(ctx context.Context) (linux.Statfs, error) { return fd.filesystem().statFS(ctx) } +// Readiness implements waiter.Waitable.Readiness. +func (fd *regularFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { + ctx := context.Background() + wrappedFD, err := fd.getCurrentFD(ctx) + if err != nil { + // TODO(b/171089913): Just use fd.cachedFD since Readiness can't return + // an error. This is obviously wrong, but at least consistent with + // VFS1. + log.Warningf("overlay.regularFileFD.Readiness: currentFDLocked failed: %v", err) + fd.mu.Lock() + wrappedFD = fd.cachedFD + wrappedFD.IncRef() + fd.mu.Unlock() + } + defer wrappedFD.DecRef(ctx) + return wrappedFD.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *regularFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fd.mu.Lock() + defer fd.mu.Unlock() + wrappedFD, err := fd.currentFDLocked(context.Background()) + if err != nil { + // TODO(b/171089913): Just use fd.cachedFD since EventRegister can't + // return an error. This is obviously wrong, but at least consistent + // with VFS1. + log.Warningf("overlay.regularFileFD.EventRegister: currentFDLocked failed: %v", err) + wrappedFD = fd.cachedFD + } + wrappedFD.EventRegister(e, mask) + if !fd.copiedUp { + if fd.lowerWaiters == nil { + fd.lowerWaiters = make(map[*waiter.Entry]waiter.EventMask) + } + fd.lowerWaiters[e] = mask + } +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *regularFileFD) EventUnregister(e *waiter.Entry) { + fd.mu.Lock() + defer fd.mu.Unlock() + fd.cachedFD.EventUnregister(e) + if !fd.copiedUp { + delete(fd.lowerWaiters, e) + } +} + // PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { +func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { wrappedFD, err := fd.getCurrentFD(ctx) if err != nil { return 0, err @@ -206,7 +277,7 @@ func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, off } // Read implements vfs.FileDescriptionImpl.Read. -func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { +func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { // Hold fd.mu during the read to serialize the file offset. fd.mu.Lock() defer fd.mu.Unlock() @@ -218,7 +289,7 @@ func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts } // PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { +func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { wrappedFD, err := fd.getCurrentFD(ctx) if err != nil { return 0, err @@ -228,7 +299,7 @@ func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, of } // Write implements vfs.FileDescriptionImpl.Write. -func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { +func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { // Hold fd.mu during the write to serialize the file offset. fd.mu.Lock() defer fd.mu.Unlock() @@ -240,7 +311,7 @@ func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opt } // Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { +func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { // Hold fd.mu during the seek to serialize the file offset. fd.mu.Lock() defer fd.mu.Unlock() @@ -252,7 +323,7 @@ func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) } // Sync implements vfs.FileDescriptionImpl.Sync. -func (fd *nonDirectoryFD) Sync(ctx context.Context) error { +func (fd *regularFileFD) Sync(ctx context.Context) error { fd.mu.Lock() if !fd.dentry().isCopiedUp() { fd.mu.Unlock() @@ -269,8 +340,18 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error { return wrappedFD.Sync(ctx) } +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *regularFileFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + wrappedFD, err := fd.getCurrentFD(ctx) + if err != nil { + return 0, err + } + defer wrappedFD.DecRef(ctx) + return wrappedFD.Ioctl(ctx, uio, args) +} + // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { +func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { if err := fd.ensureMappable(ctx, opts); err != nil { return err } @@ -278,7 +359,7 @@ func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOp } // ensureMappable ensures that fd.dentry().wrappedMappable is not nil. -func (fd *nonDirectoryFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error { +func (fd *regularFileFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error { d := fd.dentry() // Fast path if we already have a Mappable for the current top layer. diff --git a/pkg/sentry/fsimpl/overlay/save_restore.go b/pkg/sentry/fsimpl/overlay/save_restore.go new file mode 100644 index 000000000..054e17b17 --- /dev/null +++ b/pkg/sentry/fsimpl/overlay/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package overlay + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d, "overlay.dentry") + } +} diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 2e086e34c..5196a2a80 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "fd_dir_inode_refs.go", package = "proc", prefix = "fdDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdDirInode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "fd_info_dir_inode_refs.go", package = "proc", prefix = "fdInfoDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdInfoDirInode", }, @@ -30,7 +30,7 @@ go_template_instance( out = "subtasks_inode_refs.go", package = "proc", prefix = "subtasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "subtasksInode", }, @@ -41,7 +41,7 @@ go_template_instance( out = "task_inode_refs.go", package = "proc", prefix = "taskInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "taskInode", }, @@ -52,7 +52,7 @@ go_template_instance( out = "tasks_inode_refs.go", package = "proc", prefix = "tasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tasksInode", }, @@ -82,6 +82,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index fd70a07de..99abcab66 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -17,6 +17,7 @@ package proc import ( "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -24,10 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "proc" +const ( + // Name is the default filesystem name. + Name = "proc" + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType is the factory class for procfs. // @@ -63,9 +68,22 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF if err != nil { return nil, nil, err } + + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + procfs := &filesystem{ devMinor: devMinor, } + procfs.MaxCachedDentries = maxCachedDentries procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) var cgroups map[string]string @@ -74,7 +92,7 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF cgroups = data.Cgroups } - inode := procfs.newTasksInode(k, pidns, cgroups) + inode := procfs.newTasksInode(ctx, k, pidns, cgroups) var dentry kernfs.Dentry dentry.Init(&procfs.Filesystem, inode) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil @@ -94,11 +112,11 @@ type dynamicInode interface { kernfs.Inode vfs.DynamicBytesSource - Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) + Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) } -func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { - inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) +func (fs *filesystem) newInode(ctx context.Context, creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) return inode } @@ -114,8 +132,8 @@ func newStaticFile(data string) *staticFile { return &staticFile{StaticData: vfs.StaticData{Data: data}} } -func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { - return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ +func (fs *filesystem) newStaticDir(ctx context.Context, creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { + return kernfs.NewStaticDir(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) } diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index bad2fab4f..cb3c5e0fd 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -58,7 +58,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers: cgroupControllers, } // Note: credentials are overridden by taskOwnedInode. - subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) subInode.EnableLeakCheck() @@ -84,7 +84,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { return offset, syserror.ENOENT diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index b63a4eca0..19011b010 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -64,6 +64,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace "gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), "io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), "maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}), + "mem": fs.newMemInode(task, fs.NextIno(), 0400), "mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}), "mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}), "net": fs.newTaskNetDir(task), @@ -89,7 +90,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. - taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) taskInode.EnableLeakCheck() inode := &taskOwnedInode{Inode: taskInode, owner: task} @@ -144,7 +145,7 @@ var _ kernfs.Inode = (*taskOwnedInode)(nil) func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) return &taskOwnedInode{Inode: inode, owner: task} } @@ -152,7 +153,7 @@ func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linu func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero} - dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) + dir := kernfs.NewStaticDir(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) return &taskOwnedInode{Inode: dir, owner: task} } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 2c80ac5c2..d268b44be 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -64,7 +64,7 @@ type fdDir struct { } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *fdDir) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { @@ -127,15 +127,15 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { produceSymlink: true, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Lookup implements kernfs.inodeDirectory.Lookup. @@ -209,7 +209,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kern task: task, fd: fd, } - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -264,7 +264,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode { task: task, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode @@ -288,8 +288,8 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, } // IterDirents implements Inode.IterDirents. -func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdInfoDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 79f8b7e9f..ba71d0fde 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -249,7 +250,7 @@ type commInode struct { func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { inode := &commInode{task: task} - inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) + inode.DynamicBytesFile.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) return inode } @@ -366,6 +367,162 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in return int64(srclen), nil } +var _ kernfs.Inode = (*memInode)(nil) + +// memInode implements kernfs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memInode struct { + kernfs.InodeAttrs + kernfs.InodeNoStatFS + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + + task *kernel.Task + locks vfs.FileLocks +} + +func (fs *filesystem) newMemInode(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { + // Note: credentials are overridden by taskOwnedInode. + inode := &memInode{task: task} + inode.init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) + return &taskOwnedInode{Inode: inode, owner: task} +} + +func (f *memInode) init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) + } + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) +} + +// Open implements kernfs.Inode.Open. +func (f *memInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, f.task, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(f.task); err != nil { + return nil, err + } + fd := &memFD{} + if err := fd.Init(rp.Mount(), d, f, opts.Flags); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// SetStat implements kernfs.Inode.SetStat. +func (*memInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +var _ vfs.FileDescriptionImpl = (*memFD)(nil) + +// memFD implements vfs.FileDescriptionImpl for /proc/[pid]/mem. +// +// +stateify savable +type memFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + inode *memInode + + // mu guards the fields below. + mu sync.Mutex `state:"nosave"` + offset int64 +} + +// Init initializes memFD. +func (fd *memFD) Init(m *vfs.Mount, d *kernfs.Dentry, inode *memInode, flags uint32) error { + fd.LockFD.Init(&inode.locks) + if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return err + } + fd.inode = inode + return nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + switch whence { + case linux.SEEK_SET: + case linux.SEEK_CUR: + offset += fd.offset + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.offset = offset + return offset, nil +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + m, err := getMMIncRef(fd.inode.task) + if err != nil { + return 0, nil + } + defer m.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *memFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.offset, opts) + fd.offset += n + fd.mu.Unlock() + return n, err +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *memFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *memFD) Release(context.Context) {} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} + // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. // // +stateify savable @@ -657,7 +814,7 @@ var _ kernfs.Inode = (*exeSymlink)(nil) func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode { inode := &exeSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -733,7 +890,7 @@ var _ kernfs.Inode = (*cwdSymlink)(nil) func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode { inode := &cwdSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -850,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri inode := &namespaceSymlink{task: task} // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) taskInode := &taskOwnedInode{Inode: inode, owner: task} return taskInode @@ -872,8 +1029,10 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir // Create a synthetic inode to represent the namespace. fs := mnt.Filesystem().Impl().(*filesystem) + nsInode := &namespaceInode{} + nsInode.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0444) dentry := &kernfs.Dentry{} - dentry.Init(&fs.Filesystem, &namespaceInode{}) + dentry.Init(&fs.Filesystem, nsInode) vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry()) // Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1. mnt.IncRef() @@ -897,11 +1056,11 @@ type namespaceInode struct { var _ kernfs.Inode = (*namespaceInode)(nil) // Init initializes a namespace inode. -func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (i *namespaceInode) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + i.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 3425e8698..5a9ee111f 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -57,33 +57,33 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]kernfs.Inode{ - "dev": fs.newInode(root, 0444, &netDevData{stack: stack}), - "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}), + "dev": fs.newInode(task, root, 0444, &netDevData{stack: stack}), + "snmp": fs.newInode(task, root, 0444, &netSnmpData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, if the file contains a header the stub is just the header // otherwise it is an empty file. - "arp": fs.newInode(root, 0444, newStaticFile(arp)), - "netlink": fs.newInode(root, 0444, newStaticFile(netlink)), - "netstat": fs.newInode(root, 0444, &netStatData{}), - "packet": fs.newInode(root, 0444, newStaticFile(packet)), - "protocols": fs.newInode(root, 0444, newStaticFile(protocols)), + "arp": fs.newInode(task, root, 0444, newStaticFile(arp)), + "netlink": fs.newInode(task, root, 0444, newStaticFile(netlink)), + "netstat": fs.newInode(task, root, 0444, &netStatData{}), + "packet": fs.newInode(task, root, 0444, newStaticFile(packet)), + "protocols": fs.newInode(task, root, 0444, newStaticFile(protocols)), // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000, // high res timer ticks per sec (ClockGetres returns 1ns resolution). - "psched": fs.newInode(root, 0444, newStaticFile(psched)), - "ptype": fs.newInode(root, 0444, newStaticFile(ptype)), - "route": fs.newInode(root, 0444, &netRouteData{stack: stack}), - "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}), - "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}), - "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}), + "psched": fs.newInode(task, root, 0444, newStaticFile(psched)), + "ptype": fs.newInode(task, root, 0444, newStaticFile(ptype)), + "route": fs.newInode(task, root, 0444, &netRouteData{stack: stack}), + "tcp": fs.newInode(task, root, 0444, &netTCPData{kernel: k}), + "udp": fs.newInode(task, root, 0444, &netUDPData{kernel: k}), + "unix": fs.newInode(task, root, 0444, &netUnixData{kernel: k}), } if stack.SupportsIPv6() { - contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack}) - contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile("")) - contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k}) - contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6)) + contents["if_inet6"] = fs.newInode(task, root, 0444, &ifinet6{stack: stack}) + contents["ipv6_route"] = fs.newInode(task, root, 0444, newStaticFile("")) + contents["tcp6"] = fs.newInode(task, root, 0444, &netTCP6Data{kernel: k}) + contents["udp6"] = fs.newInode(task, root, 0444, newStaticFile(upd6)) } } diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 3259c3732..b81ea14bf 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -62,19 +62,19 @@ type tasksInode struct { var _ kernfs.Inode = (*tasksInode)(nil) -func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { +func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ - "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))), - "filesystems": fs.newInode(root, 0444, &filesystemsData{}), - "loadavg": fs.newInode(root, 0444, &loadavgData{}), - "sys": fs.newSysDir(root, k), - "meminfo": fs.newInode(root, 0444, &meminfoData{}), - "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), - "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), - "stat": fs.newInode(root, 0444, &statData{}), - "uptime": fs.newInode(root, 0444, &uptimeData{}), - "version": fs.newInode(root, 0444, &versionData{}), + "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), + "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), + "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), + "sys": fs.newSysDir(ctx, root, k), + "meminfo": fs.newInode(ctx, root, 0444, &meminfoData{}), + "mounts": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), + "net": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), + "stat": fs.newInode(ctx, root, 0444, &statData{}), + "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}), + "version": fs.newInode(ctx, root, 0444, &versionData{}), } inode := &tasksInode{ @@ -82,7 +82,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace fs: fs, cgroupControllers: cgroupControllers, } - inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) @@ -106,9 +106,9 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err // If it failed to parse, check if it's one of the special handled files. switch name { case selfName: - return i.newSelfSymlink(root), nil + return i.newSelfSymlink(ctx, root), nil case threadSelfName: - return i.newThreadSelfSymlink(root), nil + return i.newThreadSelfSymlink(ctx, root), nil } return nil, syserror.ENOENT } @@ -122,7 +122,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { +func (i *tasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 const FIRST_PROCESS_ENTRY = 256 diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 07c27cdd9..01b7a6678 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -43,9 +43,9 @@ type selfSymlink struct { var _ kernfs.Inode = (*selfSymlink)(nil) -func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &selfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } @@ -84,9 +84,9 @@ type threadSelfSymlink struct { var _ kernfs.Inode = (*threadSelfSymlink)(nil) -func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newThreadSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &threadSelfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 95420368d..7c7afdcfa 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -40,93 +40,93 @@ const ( ) // newSysDir returns the dentry corresponding to /proc/sys directory. -func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { - return fs.newStaticDir(root, map[string]kernfs.Inode{ - "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{ - "hostname": fs.newInode(root, 0444, &hostnameData{}), - "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)), - "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)), - "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)), +func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { + return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), + "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), + "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), + "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), }), - "vm": fs.newStaticDir(root, map[string]kernfs.Inode{ - "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}), - "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")), + "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}), + "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")), }), - "net": fs.newSysNetDir(root, k), + "net": fs.newSysNetDir(ctx, root, k), }) } // newSysNetDir returns the dentry corresponding to /proc/sys/net directory. -func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { +func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { var contents map[string]kernfs.Inode // TODO(gvisor.dev/issue/1833): Support for using the network stack in the // network namespace of the calling process. if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]kernfs.Inode{ - "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{ - "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}), - "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), - "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}), - "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), - "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}), + "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")), - "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")), - "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")), - "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")), - "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")), + "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")), + "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")), + "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")), + "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "ip_no_pmtu_disc": fs.newInode(ctx, root, 0444, newStaticFile("1")), // tcp_allowed_congestion_control tell the user what they are able to // do as an unprivledged process so we leave it empty. - "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")), - "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), - "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), + "tcp_allowed_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_available_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), + "tcp_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), // Many of the following stub files are features netstack doesn't // support. The unsupported features return "0" to indicate they are // disabled. - "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")), - "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")), - "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")), - "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")), - "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")), - "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")), + "tcp_base_mss": fs.newInode(ctx, root, 0444, newStaticFile("1280")), + "tcp_dsack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_early_retrans": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen_key": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_invalid_ratelimit": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_intvl": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_probes": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_time": fs.newInode(ctx, root, 0444, newStaticFile("7200")), + "tcp_mtu_probing": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_no_metrics_save": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_probe_interval": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_probe_threshold": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_retries1": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_retries2": fs.newInode(ctx, root, 0444, newStaticFile("15")), + "tcp_rfc1337": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_slow_start_after_idle": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_synack_retries": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "tcp_syn_retries": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_timestamps": fs.newInode(ctx, root, 0444, newStaticFile("1")), }), - "core": fs.newStaticDir(root, map[string]kernfs.Inode{ - "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")), - "message_burst": fs.newInode(root, 0444, newStaticFile("10")), - "message_cost": fs.newInode(root, 0444, newStaticFile("5")), - "optmem_max": fs.newInode(root, 0444, newStaticFile("0")), - "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")), - "somaxconn": fs.newInode(root, 0444, newStaticFile("128")), - "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")), + "core": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "default_qdisc": fs.newInode(ctx, root, 0444, newStaticFile("pfifo_fast")), + "message_burst": fs.newInode(ctx, root, 0444, newStaticFile("10")), + "message_cost": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "optmem_max": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "rmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "rmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "somaxconn": fs.newInode(ctx, root, 0444, newStaticFile("128")), + "wmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "wmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), }), } } - return fs.newStaticDir(root, contents) + return fs.newStaticDir(ctx, root, contents) } // mmapMinAddrData implements vfs.DynamicBytesSource for diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 2582ababd..7ee6227a9 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -77,6 +77,7 @@ var ( "gid_map": linux.DT_REG, "io": linux.DT_REG, "maps": linux.DT_REG, + "mem": linux.DT_REG, "mountinfo": linux.DT_REG, "mounts": linux.DT_REG, "net": linux.DT_DIR, diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index cf91ea36c..fda1fa942 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -108,13 +108,13 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // NewDentry constructs and returns a sockfs dentry. // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). -func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry { +func NewDentry(ctx context.Context, mnt *vfs.Mount) *vfs.Dentry { fs := mnt.Filesystem().Impl().(*filesystem) // File mode matches net/socket.c:sock_alloc. filemode := linux.FileMode(linux.S_IFSOCK | 0600) i := &inode{} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) + i.InodeAttrs.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) d := &kernfs.Dentry{} d.Init(&fs.Filesystem, i) diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 906cd52cb..09043b572 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "dir_refs.go", package = "sys", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -28,6 +28,7 @@ go_library( "//pkg/coverage", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index 94366d429..b13f141a8 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -29,7 +29,7 @@ import ( func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode { k := &kcovInode{} - k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) + k.InodeAttrs.Init(ctx, creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) return k } @@ -102,7 +102,7 @@ func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro func (fd *kcovFD) Release(ctx context.Context) { // kcov instances have reference counts in Linux, but this seems sufficient // for our purposes. - fd.kcov.Clear() + fd.kcov.Clear(ctx) } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1ad679830..7d2147141 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -18,6 +18,7 @@ package sys import ( "bytes" "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -29,9 +30,12 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "sysfs" -const defaultSysDirMode = linux.FileMode(0755) +const ( + // Name is the default filesystem name. + Name = "sysfs" + defaultSysDirMode = linux.FileMode(0755) + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType implements vfs.FilesystemType. // @@ -62,28 +66,40 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + fs := &filesystem{ devMinor: devMinor, } + fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) - root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "block": fs.newDir(creds, defaultSysDirMode, nil), - "bus": fs.newDir(creds, defaultSysDirMode, nil), - "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "power_supply": fs.newDir(creds, defaultSysDirMode, nil), + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "class": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "power_supply": fs.newDir(ctx, creds, defaultSysDirMode, nil), }), - "dev": fs.newDir(creds, defaultSysDirMode, nil), - "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ + "dev": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "devices": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "system": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "cpu": cpuDir(ctx, fs, creds), }), }), - "firmware": fs.newDir(creds, defaultSysDirMode, nil), - "fs": fs.newDir(creds, defaultSysDirMode, nil), + "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), "kernel": kernelDir(ctx, fs, creds), - "module": fs.newDir(creds, defaultSysDirMode, nil), - "power": fs.newDir(creds, defaultSysDirMode, nil), + "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), }) var rootD kernfs.Dentry rootD.Init(&fs.Filesystem, root) @@ -94,14 +110,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs k := kernel.KernelFromContext(ctx) maxCPUCores := k.ApplicationCores() children := map[string]kernfs.Inode{ - "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), + "online": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "possible": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "present": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), } for i := uint(0); i < maxCPUCores; i++ { - children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil) + children[fmt.Sprintf("cpu%d", i)] = fs.newDir(ctx, creds, linux.FileMode(0555), nil) } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode { @@ -111,12 +127,12 @@ func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) ker var children map[string]kernfs.Inode if coverage.KcovAvailable() { children = map[string]kernfs.Inode{ - "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{ + "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{ "kcov": fs.newKcovFile(ctx, creds), }), } } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } // Release implements vfs.FilesystemImpl.Release. @@ -140,9 +156,9 @@ type dir struct { locks vfs.FileLocks } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { d := &dir{} - d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) + d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) d.EnableLeakCheck() d.IncLinks(d.OrderedChildren.Populate(contents)) @@ -191,9 +207,9 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { +func (fs *filesystem) newCPUFile(ctx context.Context, creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { c := &cpuFile{maxCores: maxCores} - c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) + c.DynamicBytesFile.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) return c } diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 1813269e0..738c0c9cc 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -147,7 +147,12 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns FSContext: kernel.NewFSContextVFS2(root, cwd, 0022), FDTable: k.NewFDTable(), } - return k.TaskSet().NewTask(config) + t, err := k.TaskSet().NewTask(ctx, config) + if err != nil { + config.ThreadGroup.Release(ctx) + return nil, err + } + return t, nil } func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) { diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 5cd428d64..fe520b6fd 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -31,7 +31,7 @@ go_template_instance( out = "inode_refs.go", package = "tmpfs", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -48,6 +48,7 @@ go_library( "inode_refs.go", "named_pipe.go", "regular_file.go", + "save_restore.go", "socket_file.go", "symlink.go", "tmpfs.go", @@ -60,6 +61,7 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index ce4e3eda7..98680fde9 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -42,7 +42,7 @@ type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. - memFile *pgalloc.MemoryFile + memFile *pgalloc.MemoryFile `state:"nosave"` // memoryUsageKind is the memory accounting category under which pages backing // this regularFile's contents are accounted. @@ -92,7 +92,7 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ - memFile: fs.memFile, + memFile: fs.mfp.MemoryFile(), memoryUsageKind: usage.Tmpfs, seals: linux.F_SEAL_SEAL, } diff --git a/pkg/sentry/fsimpl/tmpfs/save_restore.go b/pkg/sentry/fsimpl/tmpfs/save_restore.go new file mode 100644 index 000000000..b27f75cc2 --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/save_restore.go @@ -0,0 +1,20 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tmpfs + +// afterLoad is called by stateify. +func (rf *regularFile) afterLoad() { + rf.memFile = rf.inode.fs.mfp.MemoryFile() +} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index e2a0aac69..4ce859d57 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -61,8 +61,9 @@ type FilesystemType struct{} type filesystem struct { vfsfs vfs.Filesystem - // memFile is used to allocate pages to for regular files. - memFile *pgalloc.MemoryFile + // mfp is used to allocate memory that stores regular file contents. mfp is + // immutable. + mfp pgalloc.MemoryFileProvider // clock is a realtime clock used to set timestamps in file operations. clock time.Clock @@ -106,8 +107,8 @@ type FilesystemOpts struct { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx) - if memFileProvider == nil { + mfp := pgalloc.MemoryFileProviderFromContext(ctx) + if mfp == nil { panic("MemoryFileProviderFromContext returned nil") } @@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } clock := time.RealtimeClockFromContext(ctx) fs := filesystem{ - memFile: memFileProvider.MemoryFile(), + mfp: mfp, clock: clock, devMinor: devMinor, } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 0ca750281..e265be0ee 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -6,6 +6,7 @@ go_library( name = "verity", srcs = [ "filesystem.go", + "save_restore.go", "verity.go", ], visibility = ["//pkg/sentry:internal"], @@ -15,6 +16,7 @@ go_library( "//pkg/fspath", "//pkg/marshal/primitive", "//pkg/merkletree", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel", @@ -38,10 +40,12 @@ go_test( "//pkg/context", "//pkg/fspath", "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/contexttest", "//pkg/sentry/vfs", + "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 03da505e1..2f6050cfd 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -192,7 +192,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -201,7 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's hash. @@ -215,7 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -233,7 +233,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -243,7 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -256,7 +256,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de Start: parent.lowerVD, }, &vfs.StatOptions{}) if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -267,20 +267,22 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // Verify returns with success. var buf bytes.Buffer if _, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, ReadOffset: int64(offset), - ReadSize: int64(merkletree.DigestSize()), + ReadSize: int64(merkletree.DigestSize(linux.FS_VERITY_HASH_ALG_SHA256)), Expected: parent.hash, DataAndTreeInSameFile: true, }); err != nil && err != io.EOF { - return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. @@ -312,7 +314,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { return err @@ -324,7 +326,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat }) if err == syserror.ENODATA { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return err @@ -332,7 +334,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat size, err := strconv.Atoi(merkleSize) if err != nil { - return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -342,14 +344,16 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat var buf bytes.Buffer params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - ReadOffset: 0, + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + ReadOffset: 0, // Set read size to 0 so only the metadata is verified. ReadSize: 0, Expected: d.hash, @@ -360,17 +364,57 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat } if _, err := merkletree.Verify(params); err != nil && err != io.EOF { - return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) } d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID + d.size = uint32(size) return nil } // Preconditions: fs.renameMu must be locked. d.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { + // If verity is enabled on child, we should check again whether + // the file and the corresponding Merkle tree are as expected, + // in order to catch deletion/renaming after the last time it's + // accessed. + if child.verityEnabled() { + vfsObj := fs.vfsfs.VirtualFilesystem() + // Get the path to the child dentry. This is only used + // to provide path information in failure case. + path, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD) + if err != nil { + return nil, err + } + + childVD, err := parent.getLowerAt(ctx, vfsObj, name) + if err == syserror.ENOENT { + // The file was previously accessed. If the + // file does not exist now, it indicates an + // unexpected modification to the file system. + return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) + } + if err != nil { + return nil, err + } + defer childVD.DecRef(ctx) + + childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + // The Merkle tree file was previous accessed. If it + // does not exist now, it indicates an unexpected + // modification to the file system. + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) + } + if err != nil { + return nil, err + } + + defer childMerkleVD.DecRef(ctx) + } + // If enabling verification on files/directories is not allowed // during runtime, all cached children are already verified. If // runtime enable is allowed and the parent directory is @@ -418,13 +462,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() - childFilename := fspath.Parse(name) - childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: childFilename, - }, &vfs.GetDentryOptions{}) - + childVD, childErr := parent.getLowerAt(ctx, vfsObj, name) // We will handle ENOENT separately, as it may indicate unexpected // modifications to the file system, and may cause a sentry panic. if childErr != nil && childErr != syserror.ENOENT { @@ -437,13 +475,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, defer childVD.DecRef(ctx) } - childMerkleFilename := merklePrefix + name - childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) - + childMerkleVD, childMerkleErr := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) // We will handle ENOENT separately, as it may indicate unexpected // modifications to the file system, and may cause a sentry panic. if childMerkleErr != nil && childMerkleErr != syserror.ENOENT { @@ -472,7 +504,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // corresponding Merkle tree is found. This indicates an // unexpected modification to the file system that // removed/renamed the child. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) } else if childErr == nil && childMerkleErr == syserror.ENOENT { // If in allowRuntimeEnable mode, and the Merkle tree file is // not created yet, we create an empty Merkle tree file, so that @@ -488,7 +520,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ Root: parent.lowerVD, Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), + Path: fspath.Parse(merklePrefix + name), }, &vfs.OpenOptions{ Flags: linux.O_RDWR | linux.O_CREAT, Mode: 0644, @@ -497,11 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, return nil, err } childMerkleFD.DecRef(ctx) - childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) if err != nil { return nil, err } @@ -509,7 +537,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // If runtime enable is not allowed. This indicates an // unexpected modification to the file system that // removed/renamed the Merkle tree file. - return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) } } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { // Both the child and the corresponding Merkle tree are missing. @@ -518,7 +546,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // TODO(b/167752508): Investigate possible ways to differentiate // cases that both files are deleted from cases that they never // exist in the file system. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) } mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) @@ -762,7 +790,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // missing, it indicates an unexpected modification to the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err } @@ -785,7 +813,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -810,7 +838,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }) if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -828,7 +856,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err != nil { if err == syserror.ENOENT { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } return nil, err } diff --git a/pkg/sentry/fsimpl/verity/save_restore.go b/pkg/sentry/fsimpl/verity/save_restore.go new file mode 100644 index 000000000..4a161163c --- /dev/null +++ b/pkg/sentry/fsimpl/verity/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d, "verity.dentry") + } +} diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 8dc9e26bc..92ca6ca6b 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -23,6 +23,7 @@ package verity import ( "fmt" + "math" "strconv" "sync/atomic" @@ -31,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -153,10 +155,10 @@ func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. -func alertIntegrityViolation(err error, msg string) error { +// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +func alertIntegrityViolation(msg string) error { if noCrashOnVerificationFailure { - return err + return syserror.EIO } panic(msg) } @@ -236,7 +238,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // the root Merkle file, or it's never generated. fs.vfsfs.DecRef(ctx) d.DecRef(ctx) - return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") + return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") } d.lowerMerkleVD = lowerMerkleVD @@ -289,11 +291,12 @@ type dentry struct { // fs is the owning filesystem. fs is immutable. fs *filesystem - // mode, uid and gid are the file mode, owner, and group of the file in - // the underlying file system. + // mode, uid, gid and size are the file mode, owner, group, and size of + // the file in the underlying file system. mode uint32 uid uint32 gid uint32 + size uint32 // parent is the dentry corresponding to this dentry's parent directory. // name is this dentry's name in parent. If this dentry is a filesystem @@ -331,6 +334,9 @@ func (fs *filesystem) newDentry() *dentry { fs: fs, } d.vfsd.Init(d) + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(d, "verity.dentry") + } return d } @@ -393,6 +399,9 @@ func (d *dentry) destroyLocked(ctx context.Context) { if d.lowerVD.Ok() { d.lowerVD.DecRef(ctx) } + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(d, "verity.dentry") + } if d.lowerMerkleVD.Ok() { d.lowerMerkleVD.DecRef(ctx) @@ -412,6 +421,11 @@ func (d *dentry) destroyLocked(ctx context.Context) { } } +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[verity.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { //TODO(b/159261227): Implement InotifyWithParent. @@ -448,6 +462,16 @@ func (d *dentry) verityEnabled() bool { return !d.fs.allowRuntimeEnable || len(d.hash) != 0 } +// getLowerAt returns the dentry in the underlying file system, which is +// represented by filename relative to d. +func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) { + return vfsObj.GetDentryAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(filename), + }, &vfs.GetDentryOptions{}) +} + func (d *dentry) readlink(ctx context.Context) (string, error) { return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: d.lowerVD, @@ -489,6 +513,10 @@ type fileDescription struct { // directory that contains the current file/directory. This is only used // if allowRuntimeEnable is set to true. parentMerkleWriter *vfs.FileDescription + + // off is the file offset. off is protected by mu. + mu sync.Mutex `state:"nosave"` + off int64 } // Release implements vfs.FileDescriptionImpl.Release. @@ -524,6 +552,32 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + n := int64(0) + switch whence { + case linux.SEEK_SET: + // use offset as specified + case linux.SEEK_CUR: + n = fd.off + case linux.SEEK_END: + n = int64(fd.d.size) + default: + return 0, syserror.EINVAL + } + if offset > math.MaxInt64-n { + return 0, syserror.EINVAL + } + offset += n + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + // generateMerkle generates a Merkle tree file for fd. If fd points to a file // /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash // of the generated Merkle tree and the data size is returned. If fd points to @@ -546,6 +600,8 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, params := &merkletree.GenerateParams{ TreeReader: &merkleReader, TreeWriter: &merkleWriter, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, } switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { @@ -611,7 +667,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // or directory other than the root, the parent Merkle tree file should // have also been initialized. if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { - return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") + return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } hash, dataSize, err := fd.generateMerkle(ctx) @@ -657,6 +713,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // measureVerity returns the hash of fd, saved in verityDigest. func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } var metadata linux.DigestMetadata // If allowRuntimeEnable is true, an empty fd.d.hash indicates that @@ -667,7 +726,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve if fd.d.fs.allowRuntimeEnable { return 0, syserror.ENODATA } - return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found") + return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found") } // The first part of VerityDigest is the metadata. @@ -702,6 +761,9 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag } t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } _, err := primitive.CopyInt32Out(t, flags, f) return 0, err } @@ -722,6 +784,16 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. } } +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Implement Read with PRead by setting offset. + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.mu.Unlock() + return n, err +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // No need to verify if the file is not enabled yet in @@ -742,7 +814,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { return 0, err @@ -752,7 +824,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } dataReader := vfs.FileReadWriteSeeker{ @@ -766,25 +838,37 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of } n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, ReadOffset: offset, ReadSize: dst.NumBytes(), Expected: fd.d.hash, DataAndTreeInSameFile: false, }) if err != nil { - return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } return n, err } +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index e301d35f5..c647cbfd3 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -25,10 +25,12 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -41,11 +43,18 @@ const maxDataSize = 100000 // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. -func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) { +func newVerityRoot(t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("testutil.Boot: %v", err) + } + + ctx := k.SupervisorContext() + rand.Seed(time.Now().UnixNano()) vfsObj := &vfs.VirtualFilesystem{} if err := vfsObj.Init(ctx); err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -67,16 +76,26 @@ func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, v }, }) if err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err) } root := mntns.Root() root.IncRef() + + // Use lowerRoot in the task as we modify the lower file system + // directly in many tests. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot) + if err != nil { + t.Fatalf("testutil.CreateTask: %v", err) + } + t.Helper() t.Cleanup(func() { root.DecRef(ctx) mntns.DecRef(ctx) }) - return vfsObj, root, nil + return vfsObj, root, task, nil } // newFileFD creates a new file in the verity mount, and returns the FD. The FD @@ -145,8 +164,7 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. func TestOpen(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -180,11 +198,10 @@ func TestOpen(t *testing.T) { } } -// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file -// succeeds after enabling verity for it. -func TestReadUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) +// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity +// file succeeds after enabling verity for it. +func TestPReadUnmodifiedFileSucceeds(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -213,11 +230,42 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } } +// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity +// file succeeds after enabling verity for it. +func TestReadUnmodifiedFileSucceeds(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } +} + // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -248,10 +296,10 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) { } } -// TestModifiedFileFails ensures that read from a modified verity file fails. -func TestModifiedFileFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) +// TestPReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestPReadModifiedFileFails(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -289,15 +337,59 @@ func TestModifiedFileFails(t *testing.T) { // Confirm that read from the modified file fails. buf := make([]byte, size) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.PRead succeeded with modified file") + t.Fatalf("fd.PRead succeeded, expected failure") + } +} + +// TestReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestReadModifiedFileFails(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.Read succeeded, expected failure") } } // TestModifiedMerkleFails ensures that read from a verity file fails if the // corresponding Merkle tree file is modified. func TestModifiedMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -350,8 +442,7 @@ func TestModifiedMerkleFails(t *testing.T) { // verity enabled directory fails if the hashes related to the target file in // the parent Merkle tree file is modified. func TestModifiedParentMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -428,8 +519,7 @@ func TestModifiedParentMerkleFails(t *testing.T) { // TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file // succeeds after enabling verity for it. func TestUnmodifiedStatSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -455,8 +545,7 @@ func TestUnmodifiedStatSucceeds(t *testing.T) { // TestModifiedStatFails checks that getting stat for a file with modified stat // should fail. func TestModifiedStatFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -489,3 +578,123 @@ func TestModifiedStatFails(t *testing.T) { t.Errorf("fd.Stat succeeded when it should fail") } } + +// TestOpenDeletedOrRenamedFileFails ensures that opening a deleted/renamed +// verity enabled file or the corresponding Merkle tree file fails with the +// verify error. +func TestOpenDeletedFileFails(t *testing.T) { + testCases := []struct { + // Tests removing files is remove is true. Otherwise tests + // renaming files. + remove bool + // The original file is removed/renamed if changeFile is true. + changeFile bool + // The Merkle tree file is removed/renamed if changeMerkleFile + // is true. + changeMerkleFile bool + }{ + { + remove: true, + changeFile: true, + changeMerkleFile: false, + }, + { + remove: true, + changeFile: false, + changeMerkleFile: true, + }, + { + remove: false, + changeFile: true, + changeMerkleFile: false, + }, + { + remove: false, + changeFile: true, + changeMerkleFile: false, + }, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD + if tc.remove { + if tc.changeFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + } else { + newFilename := "renamed-test-file" + if tc.changeFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("RenameAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + } + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } + }) + } +} diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index fbe6d6aa6..f31277d30 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -32,9 +32,13 @@ type Stack interface { InterfaceAddrs() map[int32][]InterfaceAddr // AddInterfaceAddr adds an address to the network interface identified by - // index. + // idx. AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // RemoveInterfaceAddr removes an address from the network interface + // identified by idx. + RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 1779cc6f3..9ebeba8a3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -15,6 +15,9 @@ package inet import ( + "bytes" + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } +// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { + interfaceAddrs, ok := s.InterfaceAddrsMap[idx] + if !ok { + return fmt.Errorf("unknown idx: %d", idx) + } + + var filteredAddrs []InterfaceAddr + for _, interfaceAddr := range interfaceAddrs { + if !bytes.Equal(interfaceAddr.Addr, addr.Addr) { + filteredAddrs = append(filteredAddrs, addr) + } + } + s.InterfaceAddrsMap[idx] = filteredAddrs + + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 9a24c6bdb..90dd4a047 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -79,7 +79,7 @@ go_template_instance( out = "fd_table_refs.go", package = "kernel", prefix = "FDTable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FDTable", }, @@ -90,7 +90,7 @@ go_template_instance( out = "fs_context_refs.go", package = "kernel", prefix = "FSContext", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FSContext", }, @@ -101,7 +101,7 @@ go_template_instance( out = "ipc_namespace_refs.go", package = "kernel", prefix = "IPCNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "IPCNamespace", }, @@ -112,7 +112,7 @@ go_template_instance( out = "process_group_refs.go", package = "kernel", prefix = "ProcessGroup", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ProcessGroup", }, @@ -123,7 +123,7 @@ go_template_instance( out = "session_refs.go", package = "kernel", prefix = "Session", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Session", }, @@ -218,6 +218,7 @@ go_library( "//pkg/amutex", "//pkg/bits", "//pkg/bpf", + "//pkg/cleanup", "//pkg/context", "//pkg/coverage", "//pkg/cpuid", @@ -228,7 +229,7 @@ go_library( "//pkg/marshal/primitive", "//pkg/metric", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/secio", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 1b9721534..0ddbe5ff6 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -19,7 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs_vfs2" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" ) @@ -27,7 +27,7 @@ import ( // +stateify savable type abstractEndpoint struct { ep transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter name string ns *AbstractSocketNamespace } @@ -57,7 +57,7 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace { // its backing socket. type boundEndpoint struct { transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter } // Release implements transport.BoundEndpoint.Release. @@ -89,7 +89,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp // // When the last reference managed by socket is dropped, ep may be removed from the // namespace. -func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error { +func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refsvfs2.RefCounter) error { a.mu.Lock() defer a.mu.Unlock() @@ -109,7 +109,7 @@ func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep tran // Remove removes the specified socket at name from the abstract socket // namespace, if it has not yet been replaced. -func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) { +func (a *AbstractSocketNamespace) Remove(name string, socket refsvfs2.RefCounter) { a.mu.Lock() defer a.mu.Unlock() diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 0ec7344cd..7aba31587 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -110,7 +110,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() - f.init() // Initialize table. + f.initNoLeakCheck() // Initialize table. f.used = 0 for fd, d := range m { if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { @@ -240,6 +240,10 @@ func (f *FDTable) String() string { case fileVFS2 != nil: vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem() + vd := fileVFS2.VirtualDentry() + if vd.Dentry() == nil { + panic(fmt.Sprintf("fd %d (type %T) has nil dentry: %#v", fd, fileVFS2.Impl(), fileVFS2)) + } name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) if err != nil { fmt.Fprintf(&buf, "<err: %v>\n", err) diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index da79e6627..3476551f3 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -31,14 +31,21 @@ type descriptorTable struct { slice unsafe.Pointer `state:".(map[int32]*descriptor)"` } -// init initializes the table. +// initNoLeakCheck initializes the table without enabling leak checking. // -// TODO(gvisor.dev/1486): Enable leak check for FDTable. -func (f *FDTable) init() { +// This is used when loading an FDTable after S/R, during which the ref count +// object itself will enable leak checking if necessary. +func (f *FDTable) initNoLeakCheck() { var slice []unsafe.Pointer // Empty slice. atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) } +// init initializes the table with leak checking. +func (f *FDTable) init() { + f.initNoLeakCheck() + f.EnableLeakCheck() +} + // get gets a file entry. // // The boolean indicates whether this was in range. diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index d46d1e1c1..41fb2a784 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -130,13 +130,15 @@ func (f *FSContext) Fork() *FSContext { f.root.IncRef() } - return &FSContext{ + ctx := &FSContext{ cwd: f.cwd, root: f.root, cwdVFS2: f.cwdVFS2, rootVFS2: f.rootVFS2, umask: f.umask, } + ctx.EnableLeakCheck() + return ctx } // WorkingDirectory returns the current working directory. @@ -147,19 +149,23 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() - f.cwd.IncRef() + if f.cwd != nil { + f.cwd.IncRef() + } return f.cwd } // WorkingDirectoryVFS2 returns the current working directory. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.cwdVFS2.IncRef() + if f.cwdVFS2.Ok() { + f.cwdVFS2.IncRef() + } return f.cwdVFS2 } @@ -218,13 +224,15 @@ func (f *FSContext) RootDirectory() *fs.Dirent { // RootDirectoryVFS2 returns the current filesystem root. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.rootVFS2.IncRef() + if f.rootVFS2.Ok() { + f.rootVFS2.IncRef() + } return f.rootVFS2 } diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go index 3f34ee0db..b87e40dd1 100644 --- a/pkg/sentry/kernel/ipc_namespace.go +++ b/pkg/sentry/kernel/ipc_namespace.go @@ -55,7 +55,7 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry { return i.shms } -// DecRef implements refs_vfs2.RefCounter.DecRef. +// DecRef implements refsvfs2.RefCounter.DecRef. func (i *IPCNamespace) DecRef(ctx context.Context) { i.IPCNamespaceRefs.DecRef(func() { i.shms.Release(ctx) diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go index 060c056df..4fcdfc541 100644 --- a/pkg/sentry/kernel/kcov.go +++ b/pkg/sentry/kernel/kcov.go @@ -199,23 +199,25 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error { } kcov.mode = linux.KCOV_MODE_INIT kcov.owningTask = nil - kcov.mappable = nil + if kcov.mappable != nil { + kcov.mappable.DecRef(ctx) + kcov.mappable = nil + } return nil } // Clear resets the mode and clears the owning task and memory mapping for kcov. // It is called when the fd corresponding to kcov is closed. Note that the mode // needs to be set so that the next call to kcov.TaskWork() will exit early. -func (kcov *Kcov) Clear() { +func (kcov *Kcov) Clear(ctx context.Context) { kcov.mu.Lock() - kcov.clearLocked() - kcov.mu.Unlock() -} - -func (kcov *Kcov) clearLocked() { kcov.mode = linux.KCOV_MODE_INIT kcov.owningTask = nil - kcov.mappable = nil + if kcov.mappable != nil { + kcov.mappable.DecRef(ctx) + kcov.mappable = nil + } + kcov.mu.Unlock() } // OnTaskExit is called when the owning task exits. It is similar to @@ -254,6 +256,7 @@ func (kcov *Kcov) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro // will look different under /proc/[pid]/maps than they do on Linux. kcov.mappable = mm.NewSpecialMappable(fmt.Sprintf("[kcov:%d]", t.ThreadID()), kcov.mfp, fr) } + kcov.mappable.IncRef() opts.Mappable = kcov.mappable opts.MappingIdentity = kcov.mappable return nil diff --git a/pkg/sentry/kernel/kcov_unsafe.go b/pkg/sentry/kernel/kcov_unsafe.go index 6f64022eb..6f8a0266b 100644 --- a/pkg/sentry/kernel/kcov_unsafe.go +++ b/pkg/sentry/kernel/kcov_unsafe.go @@ -20,9 +20,9 @@ import ( "gvisor.dev/gvisor/pkg/safemem" ) -// countBlock provides a safemem.BlockSeq for k.count. +// countBlock provides a safemem.BlockSeq for kcov.count. // // Like k.count, the block returned is protected by k.mu. -func (k *Kcov) countBlock() safemem.BlockSeq { - return safemem.BlockSeqOf(safemem.BlockFromSafePointer(unsafe.Pointer(&k.count), int(unsafe.Sizeof(k.count)))) +func (kcov *Kcov) countBlock() safemem.BlockSeq { + return safemem.BlockSeqOf(safemem.BlockFromSafePointer(unsafe.Pointer(&kcov.count), int(unsafe.Sizeof(kcov.count)))) } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 652cbb732..9b2be44d4 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -39,6 +39,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/eventchannel" @@ -340,7 +341,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { return fmt.Errorf("Timekeeper is nil") } if args.Timekeeper.clocks == nil { - return fmt.Errorf("Must call Timekeeper.SetClocks() before Kernel.Init()") + return fmt.Errorf("must call Timekeeper.SetClocks() before Kernel.Init()") } if args.RootUserNamespace == nil { return fmt.Errorf("RootUserNamespace is nil") @@ -365,7 +366,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.useHostCores = true maxCPU, err := hostcpu.MaxPossibleCPU() if err != nil { - return fmt.Errorf("Failed to get maximum CPU number: %v", err) + return fmt.Errorf("failed to get maximum CPU number: %v", err) } minAppCores := uint(maxCPU) + 1 if k.applicationCores < minAppCores { @@ -429,9 +430,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { // SaveTo saves the state of k to w. // // Preconditions: The kernel must be paused throughout the call to SaveTo. -func (k *Kernel) SaveTo(w wire.Writer) error { +func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error { saveStart := time.Now() - ctx := k.SupervisorContext() // Do not allow other Kernel methods to affect it while it's being saved. k.extMu.Lock() @@ -445,38 +445,55 @@ func (k *Kernel) SaveTo(w wire.Writer) error { k.mf.StartEvictions() k.mf.WaitForEvictions() - // Flush write operations on open files so data reaches backing storage. - // This must come after MemoryFile eviction since eviction may cause file - // writes. - if err := k.tasks.flushWritesToFiles(ctx); err != nil { - return err - } + if VFS2Enabled { + // Discard unsavable mappings, such as those for host file descriptors. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } - // Remove all epoll waiter objects from underlying wait queues. - // NOTE: for programs to resume execution in future snapshot scenarios, - // we will need to re-establish these waiter objects after saving. - k.tasks.unregisterEpollWaiters(ctx) + // Prepare filesystems for saving. This must be done after + // invalidateUnsavableMappings(), since dropping memory mappings may + // affect filesystem state (e.g. page cache reference counts). + if err := k.vfs.PrepareSave(ctx); err != nil { + return err + } + } else { + // Flush cached file writes to backing storage. This must come after + // MemoryFile eviction since eviction may cause file writes. + if err := k.flushWritesToFiles(ctx); err != nil { + return err + } - // Clear the dirent cache before saving because Dirents must be Loaded in a - // particular order (parents before children), and Loading dirents from a cache - // breaks that order. - if err := k.flushMountSourceRefs(ctx); err != nil { - return err - } + // Remove all epoll waiter objects from underlying wait queues. + // NOTE: for programs to resume execution in future snapshot scenarios, + // we will need to re-establish these waiter objects after saving. + k.tasks.unregisterEpollWaiters(ctx) - // Ensure that all inode and mount release operations have completed. - fs.AsyncBarrier() + // Clear the dirent cache before saving because Dirents must be Loaded in a + // particular order (parents before children), and Loading dirents from a cache + // breaks that order. + if err := k.flushMountSourceRefs(ctx); err != nil { + return err + } - // Once all fs work has completed (flushed references have all been released), - // reset mount mappings. This allows individual mounts to save how inodes map - // to filesystem resources. Without this, fs.Inodes cannot be restored. - fs.SaveInodeMappings() + // Ensure that all inode and mount release operations have completed. + fs.AsyncBarrier() - // Discard unsavable mappings, such as those for host file descriptors. - // This must be done after waiting for "asynchronous fs work", which - // includes async I/O that may touch application memory. - if err := k.invalidateUnsavableMappings(ctx); err != nil { - return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + // Once all fs work has completed (flushed references have all been released), + // reset mount mappings. This allows individual mounts to save how inodes map + // to filesystem resources. Without this, fs.Inodes cannot be restored. + fs.SaveInodeMappings() + + // Discard unsavable mappings, such as those for host file descriptors. + // This must be done after waiting for "asynchronous fs work", which + // includes async I/O that may touch application memory. + // + // TODO(gvisor.dev/issue/1624): This rationale is believed to be + // obsolete since AIO callbacks are now waited-for by Kernel.Pause(), + // but this order is conservatively retained for VFS1. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } } // Save the CPUID FeatureSet before the rest of the kernel so we can @@ -485,14 +502,14 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil { + if _, err := state.Save(ctx, w, k.FeatureSet()); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) // Save the kernel state. kernelStart := time.Now() - stats, err := state.Save(k.SupervisorContext(), w, k) + stats, err := state.Save(ctx, w, k) if err != nil { return err } @@ -501,7 +518,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { + if err := k.mf.SaveTo(ctx, w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -513,11 +530,9 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. +// +// Preconditions: !VFS2Enabled. func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { - if VFS2Enabled { - return nil // Not relevant. - } - // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() @@ -560,13 +575,9 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi return err } -func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return nil - } - - return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { +// Preconditions: !VFS2Enabled. +func (k *Kernel) flushWritesToFiles(ctx context.Context) error { + return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil } @@ -588,37 +599,8 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { }) } -// Preconditions: The kernel must be paused. -func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { - invalidated := make(map[*mm.MemoryManager]struct{}) - k.tasks.mu.RLock() - defer k.tasks.mu.RUnlock() - for t := range k.tasks.Root.tids { - // We can skip locking Task.mu here since the kernel is paused. - if mm := t.tc.MemoryManager; mm != nil { - if _, ok := invalidated[mm]; !ok { - if err := mm.InvalidateUnsavable(ctx); err != nil { - return err - } - invalidated[mm] = struct{}{} - } - } - // I really wish we just had a sync.Map of all MMs... - if r, ok := t.runState.(*runSyscallAfterExecStop); ok { - if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { - return err - } - } - } - return nil -} - +// Preconditions: !VFS2Enabled. func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return - } - ts.mu.RLock() defer ts.mu.RUnlock() @@ -643,8 +625,33 @@ func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { } } +// Preconditions: The kernel must be paused. +func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { + invalidated := make(map[*mm.MemoryManager]struct{}) + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + for t := range k.tasks.Root.tids { + // We can skip locking Task.mu here since the kernel is paused. + if mm := t.tc.MemoryManager; mm != nil { + if _, ok := invalidated[mm]; !ok { + if err := mm.InvalidateUnsavable(ctx); err != nil { + return err + } + invalidated[mm] = struct{}{} + } + } + // I really wish we just had a sync.Map of all MMs... + if r, ok := t.runState.(*runSyscallAfterExecStop); ok { + if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { + return err + } + } + } + return nil +} + // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error { +func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -655,7 +662,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil { + if _, err := state.Load(ctx, r, &features); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -670,7 +677,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the kernel state. kernelStart := time.Now() - stats, err := state.Load(k.SupervisorContext(), r, k) + stats, err := state.Load(ctx, r, k) if err != nil { return err } @@ -683,7 +690,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { + if err := k.mf.LoadFrom(ctx, r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -695,11 +702,17 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock net.Resume() } - // Ensure that all pending asynchronous work is complete: - // - namedpipe opening - // - inode file opening - if err := fs.AsyncErrorBarrier(); err != nil { - return err + if VFS2Enabled { + if err := k.vfs.CompleteRestore(ctx, vfsOpts); err != nil { + return err + } + } else { + // Ensure that all pending asynchronous work is complete: + // - namedpipe opening + // - inode file opening + if err := fs.AsyncErrorBarrier(); err != nil { + return err + } } tcpip.AsyncLoading.Wait() @@ -966,6 +979,10 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, } tg := k.NewThreadGroup(mntns, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits) + cu := cleanup.Make(func() { + tg.Release(ctx) + }) + defer cu.Clean() // Check which file to start from. switch { @@ -1025,13 +1042,14 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, MountNamespaceVFS2: mntnsVFS2, ContainerID: args.ContainerID, } - t, err := k.tasks.NewTask(config) + t, err := k.tasks.NewTask(ctx, config) if err != nil { return nil, 0, err } t.traceExecEvent(tc) // Simulate exec for tracing. // Success. + cu.Release() tgid := k.tasks.Root.IDOfThreadGroup(tg) if k.globalInit == nil { k.globalInit = tg diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index f61039f5b..d96bf253b 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -33,6 +33,8 @@ import ( // VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should // not be copied. +// +// +stateify savable type VFSPipe struct { // mu protects the fields below. mu sync.Mutex `state:"nosave"` @@ -164,6 +166,8 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l // VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements // non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to // other FileDescriptions for splice(2) and tee(2). +// +// +stateify savable type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -237,8 +241,7 @@ func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // PipeSize implements fcntl(F_GETPIPE_SZ). func (fd *VFSPipeFD) PipeSize() int64 { - // Inline Pipe.FifoSize() rather than calling it with nil Context and - // fs.File and ignoring the returned error (which is always nil). + // Inline Pipe.FifoSize() since we don't have a fs.File. fd.pipe.mu.Lock() defer fd.pipe.mu.Unlock() return fd.pipe.max diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index c00fa1138..c39ecfb8f 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -283,6 +283,33 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File return nil } +// GetStat extracts semid_ds information from the set. +func (s *Set) GetStat(creds *auth.Credentials) (*linux.SemidDS, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // "The calling process must have read permission on the semaphore set." + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return nil, syserror.EACCES + } + + ds := &linux.SemidDS{ + SemPerm: linux.IPCPerm{ + Key: uint32(s.key), + UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)), + GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)), + CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)), + CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)), + Mode: uint16(s.perms.LinuxMode()), + Seq: 0, // IPC sequence not supported. + }, + SemOTime: s.opTime.TimeT(), + SemCTime: s.changeTime.TimeT(), + SemNSems: uint64(s.Size()), + } + return ds, nil +} + // SetVal overrides a semaphore value, waking up waiters as needed. func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error { if val < 0 || val > valueMax { @@ -320,7 +347,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti } for _, val := range vals { - if val < 0 || val > valueMax { + if val > valueMax { return syserror.ERANGE } } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index f8a382fd8..80a592c8f 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "shm_refs.go", package = "shm", prefix = "Shm", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Shm", }, @@ -27,7 +27,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 7a053f369..682080c14 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -206,6 +207,10 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } else { ipcns.IncRef() } + cu := cleanup.Make(func() { + ipcns.DecRef(t) + }) + defer cu.Clean() netns := t.NetworkNamespace() if opts.NewNetworkNamespace { @@ -216,13 +221,18 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { mntnsVFS2 := t.mountNamespaceVFS2 if mntnsVFS2 != nil { mntnsVFS2.IncRef() + cu.Add(func() { + mntnsVFS2.DecRef(t) + }) } tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace) if err != nil { - ipcns.DecRef(t) return 0, nil, err } + cu.Add(func() { + tc.release() + }) // clone() returns 0 in the child. tc.Arch.SetReturn(0) if opts.Stack != 0 { @@ -230,7 +240,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } if opts.SetTLS { if !tc.Arch.SetTLS(uintptr(opts.TLS)) { - ipcns.DecRef(t) return 0, nil, syserror.EPERM } } @@ -299,11 +308,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } else { cfg.InheritParent = t } - nt, err := t.tg.pidns.owner.NewTask(cfg) + nt, err := t.tg.pidns.owner.NewTask(t, cfg) + // If NewTask succeeds, we transfer references to nt. If NewTask fails, it does + // the cleanup for us. + cu.Release() if err != nil { - if opts.NewThreadGroup { - tg.release(t) - } return 0, nil, err } diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index 239551eb6..ce7b9641d 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -286,7 +286,7 @@ func (*runExitMain) execute(t *Task) taskRunState { // If this is the last task to exit from the thread group, release the // thread group's resources. if lastExiter { - t.tg.release(t) + t.tg.Release(t) } // Detach tracees. diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 6e2ff573a..8e28230cc 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -16,6 +16,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -98,15 +99,18 @@ type TaskConfig struct { // NewTask creates a new task defined by cfg. // // NewTask does not start the returned task; the caller must call Task.Start. -func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) { +// +// If successful, NewTask transfers references held by cfg to the new task. +// Otherwise, NewTask releases them. +func (ts *TaskSet) NewTask(ctx context.Context, cfg *TaskConfig) (*Task, error) { t, err := ts.newTask(cfg) if err != nil { cfg.TaskContext.release() - cfg.FSContext.DecRef(t) - cfg.FDTable.DecRef(t) - cfg.IPCNamespace.DecRef(t) + cfg.FSContext.DecRef(ctx) + cfg.FDTable.DecRef(ctx) + cfg.IPCNamespace.DecRef(ctx) if cfg.MountNamespaceVFS2 != nil { - cfg.MountNamespaceVFS2.DecRef(t) + cfg.MountNamespaceVFS2.DecRef(ctx) } return nil, err } diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 0b34c0099..a183b28c1 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -307,8 +308,8 @@ func (tg *ThreadGroup) Limits() *limits.LimitSet { return tg.limits } -// release releases the thread group's resources. -func (tg *ThreadGroup) release(t *Task) { +// Release releases the thread group's resources. +func (tg *ThreadGroup) Release(ctx context.Context) { // Timers must be destroyed without holding the TaskSet or signal mutexes // since timers send signals with Timer.mu locked. tg.itimerRealTimer.Destroy() @@ -325,7 +326,7 @@ func (tg *ThreadGroup) release(t *Task) { it.DestroyTimer() } if tg.mounts != nil { - tg.mounts.DecRef(t) + tg.mounts.DecRef(ctx) } } diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index d4610ec3b..98af2cc38 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -194,6 +194,10 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { log.Infof("Too many phdrs (%d): total size %d > %d", hdr.Phnum, totalPhdrSize, maxTotalPhdrSize) return elfInfo{}, syserror.ENOEXEC } + if int64(hdr.Phoff) < 0 || int64(hdr.Phoff+uint64(totalPhdrSize)) < 0 { + ctx.Infof("Unsupported phdr offset %d", hdr.Phoff) + return elfInfo{}, syserror.ENOEXEC + } phdrBuf := make([]byte, totalPhdrSize) _, err = f.ReadFull(ctx, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff)) @@ -437,6 +441,10 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in ctx.Infof("PT_INTERP path too big: %v", phdr.Filesz) return loadedELF{}, syserror.ENOEXEC } + if int64(phdr.Off) < 0 || int64(phdr.Off+phdr.Filesz) < 0 { + ctx.Infof("Unsupported PT_INTERP offset %d", phdr.Off) + return loadedELF{}, syserror.ENOEXEC + } path := make([]byte, phdr.Filesz) _, err := f.ReadFull(ctx, usermem.BytesIOSequence(path), int64(phdr.Off)) diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index b4a47ccca..6dbeccfe2 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -78,7 +78,7 @@ go_template_instance( out = "aio_mappable_refs.go", package = "mm", prefix = "aioMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "aioMappable", }, @@ -89,7 +89,7 @@ go_template_instance( out = "special_mappable_refs.go", package = "mm", prefix = "SpecialMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SpecialMappable", }, @@ -127,6 +127,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safecopy", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 7a3311a70..5b09b9feb 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -83,6 +83,7 @@ go_library( ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/abi/linux", "//pkg/context", "//pkg/log", "//pkg/memutil", diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 626d1eaa4..7c297fb9e 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -29,6 +29,7 @@ import ( "syscall" "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" @@ -224,6 +225,18 @@ type usageInfo struct { refs uint64 } +// canCommit returns true if the tracked region can be committed. +func (u *usageInfo) canCommit() bool { + // refs must be greater than 0 because we assume that reclaimable pages + // (that aren't already known to be committed) are not committed. This + // isn't necessarily true, even after the reclaimer does Decommit(), + // because the kernel may subsequently back the hugepage-sized region + // containing the decommitted page with a hugepage. However, it's + // consistent with our treatment of unallocated pages, which have the same + // property. + return !u.knownCommitted && u.refs != 0 +} + // An EvictableMemoryUser represents a user of MemoryFile-allocated memory that // may be asked to deallocate that memory in the presence of memory pressure. type EvictableMemoryUser interface { @@ -828,6 +841,11 @@ func (f *MemoryFile) UpdateUsage() error { log.Debugf("UpdateUsage: skipped with usageSwapped!=0.") return nil } + // Linux updates usage values at CONFIG_HZ. + if scanningAfter := time.Now().Sub(f.usageLast).Milliseconds(); scanningAfter < time.Second.Milliseconds()/linux.CLOCKS_PER_SEC { + log.Debugf("UpdateUsage: skipped because previous scan happened %d ms back", scanningAfter) + return nil + } f.usageLast = time.Now() err = f.updateUsageLocked(currentUsage, mincore) @@ -841,7 +859,7 @@ func (f *MemoryFile) UpdateUsage() error { // pages by invoking checkCommitted, which is a function that, for each page i // in bs, sets committed[i] to 1 if the page is committed and 0 otherwise. // -// Precondition: f.mu must be held. +// Precondition: f.mu must be held; it may be unlocked and reacquired. func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(bs []byte, committed []byte) error) error { // Track if anything changed to elide the merge. In the common case, we // expect all segments to be committed and no merge to occur. @@ -868,7 +886,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( } else if f.usageSwapped != 0 { // We have more usage accounted for than the file itself. // That's fine, we probably caught a race where pages were - // being committed while the above loop was running. Just + // being committed while the below loop was running. Just // report the higher number that we found and ignore swap. usage.MemoryAccounting.Dec(f.usageSwapped, usage.System) f.usageSwapped = 0 @@ -880,21 +898,9 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( // Iterate over all usage data. There will only be usage segments // present when there is an associated reference. - for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - val := seg.Value() - - // Already known to be committed; ignore. - if val.knownCommitted { - continue - } - - // Assume that reclaimable pages (that aren't already known to be - // committed) are not committed. This isn't necessarily true, even - // after the reclaimer does Decommit(), because the kernel may - // subsequently back the hugepage-sized region containing the - // decommitted page with a hugepage. However, it's consistent with our - // treatment of unallocated pages, which have the same property. - if val.refs == 0 { + for seg := f.usage.FirstSegment(); seg.Ok(); { + if !seg.ValuePtr().canCommit() { + seg = seg.NextSegment() continue } @@ -917,56 +923,53 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( } // Query for new pages in core. - if err := checkCommitted(s, buf); err != nil { + // NOTE(b/165896008): mincore (which is passed as checkCommitted) + // by f.UpdateUsage() might take a really long time. So unlock f.mu + // while checkCommitted runs. + f.mu.Unlock() + err := checkCommitted(s, buf) + f.mu.Lock() + if err != nil { checkErr = err return } // Scan each page and switch out segments. - populatedRun := false - populatedRunStart := 0 - for i := 0; i <= bufLen; i++ { - // We run past the end of the slice here to - // simplify the logic and only set populated if - // we're still looking at elements. - populated := false - if i < bufLen { - populated = buf[i]&0x1 != 0 - } - - switch { - case populated == populatedRun: - // Keep the run going. - continue - case populated && !populatedRun: - // Begin the run. - populatedRun = true - populatedRunStart = i - // Keep going. + seg := f.usage.LowerBoundSegment(r.Start) + for i := 0; i < bufLen; { + if buf[i]&0x1 == 0 { + i++ continue - case !populated && populatedRun: - // Finish the run by changing this segment. - runRange := memmap.FileRange{ - Start: r.Start + uint64(populatedRunStart*usermem.PageSize), - End: r.Start + uint64(i*usermem.PageSize), + } + // Scan to the end of this committed range. + j := i + 1 + for ; j < bufLen; j++ { + if buf[j]&0x1 == 0 { + break } - seg = f.usage.Isolate(seg, runRange) - seg.ValuePtr().knownCommitted = true - // Advance the segment only if we still - // have work to do in the context of - // the original segment from the for - // loop. Otherwise, the for loop itself - // will advance the segment - // appropriately. - if runRange.End != r.End { - seg = seg.NextSegment() + } + committedFR := memmap.FileRange{ + Start: r.Start + uint64(i*usermem.PageSize), + End: r.Start + uint64(j*usermem.PageSize), + } + // Advance seg to committedFR.Start. + for seg.Ok() && seg.End() < committedFR.Start { + seg = seg.NextSegment() + } + // Mark pages overlapping committedFR as committed. + for seg.Ok() && seg.Start() < committedFR.End { + if seg.ValuePtr().canCommit() { + seg = f.usage.Isolate(seg, committedFR) + seg.ValuePtr().knownCommitted = true + amount := seg.Range().Length() + usage.MemoryAccounting.Inc(amount, seg.ValuePtr().kind) + f.usageExpected += amount + changedAny = true } - amount := runRange.Length() - usage.MemoryAccounting.Inc(amount, val.kind) - f.usageExpected += amount - changedAny = true - populatedRun = false + seg = seg.NextSegment() } + // Continue scanning for committed pages. + i = j + 1 } // Advance r.Start. @@ -978,6 +981,9 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( if err != nil { return err } + + // Continue with the first segment after r.End. + seg = f.usage.LowerBoundSegment(r.End) } return nil diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index ed5ae03d3..58f3d6fdd 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -39,6 +39,16 @@ var ( } ) +// getTLS returns the value of TPIDR_EL0 register. +// +//go:nosplit +func getTLS() (value uint64) + +// setTLS writes the TPIDR_EL0 value. +// +//go:nosplit +func setTLS(value uint64) + // bluepillArchEnter is called during bluepillEnter. // //go:nosplit @@ -51,6 +61,8 @@ func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) { regs.Pstate = context.Pstate regs.Pstate &^= uint64(ring0.PsrFlagsClear) regs.Pstate |= ring0.KernelFlagsSet + regs.TPIDR_EL0 = getTLS() + return } @@ -65,6 +77,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { context.Pstate = regs.Pstate context.Pstate &^= uint64(ring0.PsrFlagsClear) context.Pstate |= ring0.UserFlagsSet + setTLS(regs.TPIDR_EL0) lazyVfp := c.GetLazyVFP() if lazyVfp != 0 { diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 04efa0147..09c7e88e5 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -32,6 +32,18 @@ #define CONTEXT_PC 0x1B8 #define CONTEXT_R0 0xB8 +// getTLS returns the value of TPIDR_EL0 register. +TEXT ·getTLS(SB),NOSPLIT,$0-8 + MRS TPIDR_EL0, R1 + MOVD R1, ret+0(FP) + RET + +// setTLS writes the TPIDR_EL0 value. +TEXT ·setTLS(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + MSR R1, TPIDR_EL0 + RET + // See bluepill.go. TEXT ·bluepill(SB),NOSPLIT,$0 begin: diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index 84df0f878..5831b9345 100644 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -38,6 +38,8 @@ const ( _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080 _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082 _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600 + _KVM_ARM64_REGS_TIMER_CNT = 0x603000000013df1a + _KVM_ARM64_REGS_CNTFRQ_EL0 = 0x603000000013df00 ) // Arm64: Architectural Feature Access Control Register EL1. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 61ed24d01..f70d761fd 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -625,3 +626,35 @@ func (c *vCPU) BounceToKernel() { func (c *vCPU) BounceToHost() { c.bounce(true) } + +// setSystemTimeLegacy calibrates and sets an approximate system time. +func (c *vCPU) setSystemTimeLegacy() error { + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Try to set the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + start := uint64(ktime.Rdtsc()) + if err := c.setTSC(start + (minimum / 2)); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? + upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && current <= upperThreshold { + return nil + } + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c67127d95..a8b729e62 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -252,38 +252,6 @@ func (c *vCPU) setSystemTime() error { } } -// setSystemTimeLegacy calibrates and sets an approximate system time. -func (c *vCPU) setSystemTimeLegacy() error { - const minIterations = 10 - minimum := uint64(0) - for iter := 0; ; iter++ { - // Try to set the TSC to an estimate of where it will be - // on the host during a "fast" system call iteration. - start := uint64(ktime.Rdtsc()) - if err := c.setTSC(start + (minimum / 2)); err != nil { - return err - } - // See if this is our new minimum call time. Note that this - // serves two functions: one, we make sure that we are - // accurately predicting the offset we need to set. Second, we - // don't want to do the final set on a slow call, which could - // produce a really bad result. - end := uint64(ktime.Rdtsc()) - if end < start { - continue // Totally bogus: unstable TSC? - } - current := end - start - if current < minimum || iter == 0 { - minimum = current // Set our new minimum. - } - // Is this past minIterations and within ~10% of minimum? - upperThreshold := (((minimum << 3) + minimum) >> 3) - if iter >= minIterations && current <= upperThreshold { - return nil - } - } -} - // nonCanonical generates a canonical address return. // //go:nosplit diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index a163f956d..1344ed3c9 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -159,9 +159,33 @@ func (c *vCPU) initArchState() error { } c.floatingPointState = arch.NewFloatingPointData() + + return c.setSystemTime() +} + +// setTSC sets the counter Virtual Offset. +func (c *vCPU) setTSC(value uint64) error { + var ( + reg kvmOneReg + data uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + reg.id = _KVM_ARM64_REGS_TIMER_CNT + data = uint64(value) + + if err := c.setOneRegister(®); err != nil { + return err + } + return nil } +// setSystemTime sets the vCPU to the system time. +func (c *vCPU) setSystemTime() error { + return c.setSystemTimeLegacy() +} + //go:nosplit func (c *vCPU) loadSegments(tid uint64) { // TODO(gvisor.dev/issue/1238): TLS is not supported. @@ -235,8 +259,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return c.fault(int32(syscall.SIGSEGV), info) case ring0.Vector(bounce): // ring0.VirtualizationException return usermem.NoAccess, platform.ErrContextInterrupt - case ring0.El0Sync_undef, - ring0.El1Sync_undef: + case ring0.El0Sync_undef: + return c.fault(int32(syscall.SIGILL), info) + case ring0.El1Sync_undef: *info = arch.SignalInfo{ Signo: int32(syscall.SIGILL), Code: 1, // ILL_ILLOPC (illegal opcode). diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2370a9276..1079a024b 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -366,6 +366,19 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// EXCEPTION_WITH_ERROR is a common exception handler function. +#define EXCEPTION_WITH_ERROR(user, vector) \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + WORD $0xd538601a; \ //MRS FAR_EL1, R26 + MOVD R26, CPU_FAULT_ADDR(RSV_REG); \ + MOVD $user, R3; \ + MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user. + MOVD $vector, R3; \ + MOVD R3, CPU_VECTOR_CODE(RSV_REG); \ + MRS ESR_EL1, R3; \ + MOVD R3, CPU_ERROR_CODE(RSV_REG); \ + B ·kernelExitToEl1(SB); + // storeAppASID writes the application's asid value. TEXT ·storeAppASID(SB),NOSPLIT,$0-8 MOVD asid+0(FP), R1 @@ -659,21 +672,7 @@ el0_svc: el0_da: el0_ia: - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - WORD $0xd538601a //MRS FAR_EL1, R26 - - MOVD R26, CPU_FAULT_ADDR(RSV_REG) - - MOVD $1, R3 - MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. - - MOVD $PageFault, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - MRS ESR_EL1, R3 - MOVD R3, CPU_ERROR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) + EXCEPTION_WITH_ERROR(1, PageFault) el0_fpsimd_acc: B ·Shutdown(SB) @@ -688,10 +687,7 @@ el0_sp_pc: B ·Shutdown(SB) el0_undef: - MOVD $El0Sync_undef, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) + EXCEPTION_WITH_ERROR(1, El0Sync_undef) el0_dbg: B ·Shutdown(SB) diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 2f1abcb0f..d91a09de1 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -53,12 +53,6 @@ func LoadFloatingPoint(*byte) // SaveFloatingPoint saves floating point state. func SaveFloatingPoint(*byte) -// GetTLS returns the value of TPIDR_EL0 register. -func GetTLS() (value uint64) - -// SetTLS writes the TPIDR_EL0 value. -func SetTLS(value uint64) - // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 8aabf7d0e..da9d3cf55 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -29,16 +29,6 @@ TEXT ·FlushTlbAll(SB),NOSPLIT,$0 ISB $15 RET -TEXT ·GetTLS(SB),NOSPLIT,$0-8 - MRS TPIDR_EL0, R1 - MOVD R1, ret+0(FP) - RET - -TEXT ·SetTLS(SB),NOSPLIT,$0-8 - MOVD addr+0(FP), R1 - MSR R1, TPIDR_EL0 - RET - TEXT ·CPACREL1(SB),NOSPLIT,$0-8 WORD $0xd5381041 // MRS CPACR_EL1, R1 MOVD R1, ret+0(FP) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go index 1a49f12a2..5ddd10256 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -36,7 +36,7 @@ const ( pudSize = 1 << pudShift pgdSize = 1 << pgdShift - ttbrASIDOffset = 55 + ttbrASIDOffset = 48 ttbrASIDMask = 0xff entriesPerPage = 512 diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index d9621968c..37d02948f 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -24,6 +24,8 @@ import ( ) // SCMRightsVFS2 represents a SCM_RIGHTS socket control message. +// +// +stateify savable type SCMRightsVFS2 interface { transport.RightsControlMessage @@ -34,9 +36,11 @@ type SCMRightsVFS2 interface { Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool) } -// RightsFiles represents a SCM_RIGHTS socket control message. A reference is -// maintained for each vfs.FileDescription and is release either when an FD is created or -// when the Release method is called. +// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference +// is maintained for each vfs.FileDescription and is release either when an FD +// is created or when the Release method is called. +// +// +stateify savable type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 163af329b..9a2cac40b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// +stateify savable type socketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{}) func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &socketVFS2{ diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index faa61160e..7e7857ac3 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { + return syserror.EACCES +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return syserror.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(enabled bool) error { +func (s *Stack) SetTCPSACKEnabled(bool) error { return syserror.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { +func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return syserror.EACCES } @@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { } if rawLine == "" { - return fmt.Errorf("Failed to get raw line") + return fmt.Errorf("failed to get raw line") } parts := strings.SplitN(rawLine, ":", 2) if len(parts) != 2 { - return fmt.Errorf("Failed to get prefix from: %q", rawLine) + return fmt.Errorf("failed to get prefix from: %q", rawLine) } sliceStat = toSlice(stat) fields := strings.Fields(strings.TrimSpace(parts[1])) if len(fields) != len(sliceStat) { - return fmt.Errorf("Failed to parse fields: %q", rawLine) + return fmt.Errorf("failed to parse fields: %q", rawLine) } if _, ok := stat.(*inet.StatSNMPTCP); ok { snmpTCP = true @@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) } if err != nil { - return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err) } } @@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { } // SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { +func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 844acfede..352c51390 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.TCPProtocolNumber { - return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) + return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber) } return &TCPMatcher{ diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 63201201c..c88d8268d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber) } return &UDPMatcher{ diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go index e8930f031..f061c5d62 100644 --- a/pkg/sentry/socket/netlink/provider_vfs2.go +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol vfsfd := &s.vfsfd mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index c84d8bd7c..22216158e 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -36,9 +36,9 @@ type commandKind int const ( kindNew commandKind = 0x0 - kindDel = 0x1 - kindGet = 0x2 - kindSet = 0x3 + kindDel commandKind = 0x1 + kindGet commandKind = 0x2 + kindSet commandKind = 0x3 ) func typeKind(typ uint16) commandKind { @@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } attrs = rest + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We add the local interface address here + // and ignore the IFA_ADDRESS. switch ahdr.Type { case linux.IFA_LOCAL: err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ @@ -439,8 +444,57 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } else if err != nil { return syserr.ErrInvalidArgument } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported + } + } + return nil +} + +// delAddr handles RTM_DELADDR requests. +func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We use the local interface address to + // remove the address and ignore the IFA_ADDRESS. + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err != nil { + return syserr.ErrInvalidArgument + } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported } } + return nil } @@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: return p.newAddr(ctx, msg, ms) + case linux.RTM_DELADDR: + return p.delAddr(ctx, msg, ms) default: return syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 5ddcd4be5..3baad098b 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -16,6 +16,7 @@ package netlink import ( + "io" "math" "gvisor.dev/gvisor/pkg/abi/linux" @@ -748,6 +749,12 @@ func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, t buf := make([]byte, src.NumBytes()) n, err := src.CopyIn(ctx, buf) + // io.EOF can be only returned if src is a file, this means that + // sendMsg is called from splice and the error has to be ignored in + // this case. + if err == io.EOF { + err = nil + } if err != nil { // Don't partially consume messages. return 0, syserr.FromError(err) diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index c83b23242..461d524e5 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -37,6 +37,8 @@ import ( // to/from the kernel. // // SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 87e30d742..86c634715 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -587,6 +587,11 @@ func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { + // EOF can be returned only if src is a file and this means it + // is in a splice syscall and the error has to be ignored. + if err == io.EOF { + return v, nil + } return nil, tcpip.ErrBadAddress } return v, nil @@ -1239,6 +1244,18 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ACCEPTCONN: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.AcceptConnOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 4c6791fff..b0d9e4d9e 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -35,6 +35,8 @@ import ( // SocketVFS2 encapsulates all the state needed to represent a network stack // endpoint in the kernel context. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -55,7 +57,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu } mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &SocketVFS2{ diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 1028d2a6e..fa9ac9059 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } -// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +// convertAddr converts an InterfaceAddr to a ProtocolAddress. +func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) { var ( - protocol tcpip.NetworkProtocolNumber - address tcpip.Address + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + protocolAddress tcpip.ProtocolAddress ) switch addr.Family { case linux.AF_INET: - if len(addr.Addr) < header.IPv4AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv4AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv4AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv4.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) - + address = tcpip.Address(addr.Addr) case linux.AF_INET6: - if len(addr.Addr) < header.IPv6AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv6AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv6AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv6.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) - + address = tcpip.Address(addr.Addr) default: - return syserror.ENOTSUP + return protocolAddress, syserror.ENOTSUP } - protocolAddress := tcpip.ProtocolAddress{ + protocolAddress = tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: int(addr.PrefixLen), }, } + return protocolAddress, nil +} + +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } // Attach address to interface. - if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + nicID := tcpip.NICID(idx) + if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network if it doesn't exist already. + localRoute := tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: nicID, + } + + for _, rt := range s.Stack.GetRouteTable() { + if rt.Equal(localRoute) { + return nil + } + } + + // Local route does not exist yet. Add it. + s.Stack.AddRoute(localRoute) + + return nil +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } + + // Remove addresses matching the address and prefix. + nicID := tcpip.NICID(idx) + if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil { return syserr.TranslateNetstackError(err).ToError() } - // Add route for local network. - s.Stack.AddRoute(tcpip.Route{ + // Remove the corresponding local network route if it exists. + localRoute := tcpip.Route{ Destination: protocolAddress.AddressWithPrefix.Subnet(), Gateway: "", // No gateway for local network. - NIC: tcpip.NICID(idx), + NIC: nicID, + } + s.Stack.RemoveRoutes(func(rt tcpip.Route) bool { + return rt.Equal(localRoute) }) + return nil } diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cc7408698..cce0acc33 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "socket_refs.go", package = "unix", prefix = "socketOperations", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketOperations", }, @@ -19,7 +19,7 @@ go_template_instance( out = "socket_vfs2_refs.go", package = "unix", prefix = "socketVFS2", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketVFS2", }, @@ -43,6 +43,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 26c3a51b9..3ebbd28b0 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "queue_refs.go", package = "transport", prefix = "queue", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "queue", }, @@ -44,6 +44,7 @@ go_library( "//pkg/ilist", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index d6fc03520..b648273a4 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -32,6 +32,8 @@ import ( const initialLimit = 16 * 1024 // A RightsControlMessage is a control message containing FDs. +// +// +stateify savable type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. Clone() RightsControlMessage @@ -336,7 +338,7 @@ type Receiver interface { RecvMaxQueueSize() int64 // Release releases any resources owned by the Receiver. It should be - // called before droping all references to a Receiver. + // called before dropping all references to a Receiver. Release(ctx context.Context) } @@ -487,7 +489,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds c := q.control.Clone() // Don't consume data since we are peeking. - copied, data, _ = vecCopy(data, q.buffer) + copied, _, _ = vecCopy(data, q.buffer) return copied, copied, c, false, q.addr, notify, nil } @@ -572,6 +574,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds return copied, copied, c, cmTruncated, q.addr, notify, nil } +// Release implements Receiver.Release. +func (q *streamQueueReceiver) Release(ctx context.Context) { + q.queueReceiver.Release(ctx) + q.control.Release(ctx) +} + // A ConnectedEndpoint is an Endpoint that can be used to send Messages. type ConnectedEndpoint interface { // Passcred implements Endpoint.Passcred. @@ -619,7 +627,7 @@ type ConnectedEndpoint interface { SendMaxQueueSize() int64 // Release releases any resources owned by the ConnectedEndpoint. It should - // be called before droping all references to a ConnectedEndpoint. + // be called before dropping all references to a ConnectedEndpoint. Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to @@ -879,7 +887,7 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil case tcpip.PasscredOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index a4a76d0a3..adad485a9 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -81,7 +81,6 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty }, } s.EnableLeakCheck() - return fs.NewFile(ctx, d, flags, &s) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 678355fb9..7a78444dc 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // returns a corresponding file description. func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) @@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 stype: stype, }, } + sock.EnableLeakCheck() sock.LockFD.Init(locks) vfsfd := &sock.vfsfd if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 0ea4aab8b..563d60578 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -12,10 +12,12 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/time", + "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/state/statefile", "//pkg/syserror", diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 245d2c5cf..167754537 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -19,10 +19,12 @@ import ( "fmt" "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/pkg/syserror" @@ -57,7 +59,7 @@ type SaveOpts struct { } // Save saves the system state. -func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { +func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Watchdog) error { log.Infof("Sandbox save started, pausing all tasks.") k.Pause() k.ReceiveTaskStates() @@ -81,7 +83,7 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { err = ErrStateFile{err} } else { // Save the kernel. - err = k.SaveTo(wc) + err = k.SaveTo(ctx, wc) // ENOSPC is a state file error. This error can only come from // writing the state file, and not from fs.FileOperations.Fsync @@ -108,7 +110,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error { +func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -118,5 +120,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) er previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(r, n, clocks) + return k.LoadFrom(ctx, r, n, clocks, vfsOpts) } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 9c9def7cd..36902d177 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 47dadb800..c2d4bf805 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -129,9 +129,17 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal v, err := getPID(t, id, num) return uintptr(v), nil, err + case linux.IPC_STAT: + arg := args[3].Pointer() + ds, err := ipcStat(t, id) + if err == nil { + _, err = ds.CopyOut(t, arg) + } + + return 0, nil, err + case linux.IPC_INFO, linux.SEM_INFO, - linux.IPC_STAT, linux.SEM_STAT, linux.SEM_STAT_ANY, linux.GETNCNT, @@ -171,6 +179,16 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP return set.Change(t, creds, owner, perms) } +func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.GetStat(creds) +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go index 6320593f0..db3d924d9 100644 --- a/pkg/sentry/syscalls/linux/sys_sysinfo.go +++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" ) -// Sysinfo implements the sysinfo syscall as described in man 2 sysinfo. +// Sysinfo implements Linux syscall sysinfo(2). func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index d8b8d9783..36e89700e 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -145,16 +145,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(file.StatusFlags()), nil, nil case linux.F_SETFL: return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint()) - case linux.F_SETPIPE_SZ: - pipefile, ok := file.Impl().(*pipe.VFSPipeFD) - if !ok { - return 0, nil, syserror.EBADF - } - n, err := pipefile.SetPipeSize(int64(args[2].Int())) - if err != nil { - return 0, nil, err - } - return uintptr(n), nil, nil case linux.F_GETOWN: owner, hasOwner := getAsyncOwner(t, file) if !hasOwner { @@ -190,6 +180,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err } return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID) + case linux.F_SETPIPE_SZ: + pipefile, ok := file.Impl().(*pipe.VFSPipeFD) + if !ok { + return 0, nil, syserror.EBADF + } + n, err := pipefile.SetPipeSize(int64(args[2].Int())) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil case linux.F_GETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index bf5c1171f..035e2a6b0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -45,6 +45,9 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if count > int64(kernel.MAX_RW_COUNT) { count = int64(kernel.MAX_RW_COUNT) } + if count < 0 { + return 0, nil, syserror.EINVAL + } // Check for invalid flags. if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { @@ -192,6 +195,9 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if count > int64(kernel.MAX_RW_COUNT) { count = int64(kernel.MAX_RW_COUNT) } + if count < 0 { + return 0, nil, syserror.EINVAL + } // Check for invalid flags. if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index ab1d140d2..5ed6726ab 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -278,7 +278,7 @@ func TotalMemory(memSize, used uint64) uint64 { } if memSize < used { memSize = used - // Bump totalSize to the next largest power of 2, if one exists, so + // Bump memSize to the next largest power of 2, if one exists, so // that MemFree isn't 0. if msb := bits.MostSignificantOne64(memSize); msb < 63 { memSize = uint64(1) << (uint(msb) + 1) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index c855608db..440c9307c 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -32,7 +32,7 @@ go_template_instance( out = "file_description_refs.go", package = "vfs", prefix = "FileDescription", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FileDescription", }, @@ -43,7 +43,7 @@ go_template_instance( out = "mount_namespace_refs.go", package = "vfs", prefix = "MountNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "MountNamespace", }, @@ -54,7 +54,7 @@ go_template_instance( out = "filesystem_refs.go", package = "vfs", prefix = "Filesystem", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Filesystem", }, @@ -87,6 +87,7 @@ go_library( "pathname.go", "permissions.go", "resolving_path.go", + "save_restore.go", "vfs.go", ], visibility = ["//pkg/sentry:internal"], @@ -99,6 +100,7 @@ go_library( "//pkg/gohacks", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 8f36c3e3b..a98aac52b 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -74,7 +74,7 @@ type epollInterestKey struct { // +stateify savable type epollInterest struct { // epoll is the owning EpollInstance. epoll is immutable. - epoll *EpollInstance + epoll *EpollInstance `state:"wait"` // key is the file to which this epollInterest applies. key is immutable. key epollInterestKey diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 183957ad8..546e445aa 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -183,7 +183,6 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.vd.DecRef(ctx) fd.flagsMu.Lock() - // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { fd.asyncHandler.Unregister(fd) } diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 2d27d9d35..ba6e6ed49 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -71,7 +71,7 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { return vfs.PrependPathAtVFSRootError{} } - if &d.vfsd == mnt.Root() { + if mnt != nil && &d.vfsd == mnt.Root() { return nil } if d.parent == nil { @@ -81,3 +81,12 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath d = d.parent } } + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func DebugPathname(d *Dentry) string { + var b fspath.Builder + _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 55783d4eb..1ff202f2a 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package lock provides POSIX and BSD style file locking for VFS2 file -// implementations. -// -// The actual implementations can be found in the lock package under -// sentry/fs/lock. package vfs import ( diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 78f115bfa..d452d2cda 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) @@ -106,6 +107,9 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount if opts.ReadOnly { mnt.setReadOnlyLocked(true) } + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(mnt, "vfs.Mount") + } return mnt } @@ -489,26 +493,38 @@ func (mnt *Mount) IncRef() { // DecRef decrements mnt's reference count. func (mnt *Mount) DecRef(ctx context.Context) { - refs := atomic.AddInt64(&mnt.refs, -1) - if refs&^math.MinInt64 == 0 { // mask out MSB - var vd VirtualDentry - if mnt.parent() != nil { - mnt.vfs.mountMu.Lock() - mnt.vfs.mounts.seq.BeginWrite() - vd = mnt.vfs.disconnectLocked(mnt) - mnt.vfs.mounts.seq.EndWrite() - mnt.vfs.mountMu.Unlock() - } - if mnt.root != nil { - mnt.root.DecRef(ctx) - } - mnt.fs.DecRef(ctx) - if vd.Ok() { - vd.DecRef(ctx) + r := atomic.AddInt64(&mnt.refs, -1) + if r&^math.MinInt64 == 0 { // mask out MSB + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(mnt, "vfs.Mount") } + mnt.destroy(ctx) } } +func (mnt *Mount) destroy(ctx context.Context) { + var vd VirtualDentry + if mnt.parent() != nil { + mnt.vfs.mountMu.Lock() + mnt.vfs.mounts.seq.BeginWrite() + vd = mnt.vfs.disconnectLocked(mnt) + mnt.vfs.mounts.seq.EndWrite() + mnt.vfs.mountMu.Unlock() + } + if mnt.root != nil { + mnt.root.DecRef(ctx) + } + mnt.fs.DecRef(ctx) + if vd.Ok() { + vd.DecRef(ctx) + } +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (mnt *Mount) LeakMessage() string { + return fmt.Sprintf("[vfs.Mount %p] reference count of %d instead of 0", mnt, atomic.LoadInt64(&mnt.refs)) +} + // DecRef decrements mntns' reference count. func (mntns *MountNamespace) DecRef(ctx context.Context) { vfs := mntns.root.fs.VirtualFilesystem() diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b7d122d22..cb48c37a1 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -98,7 +98,6 @@ type mountTable struct { // length and cap in separate uint32s) for ~free. size uint64 - // FIXME(gvisor.dev/issue/1663): Slots need to be saved. slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init } @@ -212,6 +211,26 @@ loop: } } +// Range calls f on each Mount in mt. If f returns false, Range stops iteration +// and returns immediately. +func (mt *mountTable) Range(f func(*Mount) bool) { + tcap := uintptr(1) << (mt.size & mtSizeOrderMask) + slotPtr := mt.slots + last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes)) + for { + slot := (*mountSlot)(slotPtr) + if slot.value != nil { + if !f((*Mount)(slot.value)) { + return + } + } + if slotPtr == last { + return + } + slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes) + } +} + // Insert inserts the given mount into mt. // // Preconditions: mt must not already contain a Mount with the same mount point diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go new file mode 100644 index 000000000..46e50d55d --- /dev/null +++ b/pkg/sentry/vfs/save_restore.go @@ -0,0 +1,124 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// FilesystemImplSaveRestoreExtension is an optional extension to +// FilesystemImpl. +type FilesystemImplSaveRestoreExtension interface { + // PrepareSave prepares this filesystem for serialization. + PrepareSave(ctx context.Context) error + + // CompleteRestore completes restoration from checkpoint for this + // filesystem after deserialization. + CompleteRestore(ctx context.Context, opts CompleteRestoreOptions) error +} + +// PrepareSave prepares all filesystems for serialization. +func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.PrepareSave(ctx); err != nil { + ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to prepare for serialization", failures) + } + return nil +} + +// CompleteRestore completes restoration from checkpoint for all filesystems +// after deserialization. +func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.CompleteRestore(ctx, *opts); err != nil { + ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures) + } + return nil +} + +// CompleteRestoreOptions contains options to +// VirtualFilesystem.CompleteRestore() and +// FilesystemImplSaveRestoreExtension.CompleteRestore(). +type CompleteRestoreOptions struct { + // If ValidateFileSizes is true, filesystem implementations backed by + // remote filesystems should verify that file sizes have not changed + // between checkpoint and restore. + ValidateFileSizes bool + + // If ValidateFileModificationTimestamps is true, filesystem + // implementations backed by remote filesystems should validate that file + // mtimes have not changed between checkpoint and restore. + ValidateFileModificationTimestamps bool +} + +// saveMounts is called by stateify. +func (vfs *VirtualFilesystem) saveMounts() []*Mount { + if atomic.LoadPointer(&vfs.mounts.slots) == nil { + // vfs.Init() was never called. + return nil + } + var mounts []*Mount + vfs.mounts.Range(func(mount *Mount) bool { + mounts = append(mounts, mount) + return true + }) + return mounts +} + +// loadMounts is called by stateify. +func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { + if mounts == nil { + return + } + vfs.mounts.Init() + for _, mount := range mounts { + vfs.mounts.Insert(mount) + } +} + +func (mnt *Mount) afterLoad() { + if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&mnt.refs) != 0 { + refsvfs2.Register(mnt, "vfs.Mount") + } +} + +// afterLoad is called by stateify. +func (epi *epollInterest) afterLoad() { + // Mark all epollInterests as ready after restore so that the next call to + // EpollInstance.ReadEvents() rechecks their readiness. + epi.Callback(nil) +} diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 38d2701d2..48d6252f7 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -71,7 +71,7 @@ type VirtualFilesystem struct { // points. // // mounts is analogous to Linux's mount_hashtable. - mounts mountTable + mounts mountTable `state:".([]*Mount)"` // mountpoints maps mount points to mounts at those points in all // namespaces. mountpoints is protected by mountMu. @@ -780,23 +780,27 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre // SyncAllFilesystems has the semantics of Linux's sync(2). func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + var retErr error + for fs := range vfs.getFilesystems() { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef(ctx) + } + return retErr +} + +func (vfs *VirtualFilesystem) getFilesystems() map[*Filesystem]struct{} { fss := make(map[*Filesystem]struct{}) vfs.filesystemsMu.Lock() + defer vfs.filesystemsMu.Unlock() for fs := range vfs.filesystems { if !fs.TryIncRef() { continue } fss[fs] = struct{}{} } - vfs.filesystemsMu.Unlock() - var retErr error - for fs := range fss { - if err := fs.impl.Sync(ctx); err != nil && retErr == nil { - retErr = err - } - fs.DecRef(ctx) - } - return retErr + return fss } // MkdirAllAt recursively creates non-existent directories on the given path diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD index ba2ed1ea7..abb8c3be3 100644 --- a/pkg/shim/v2/runtimeoptions/BUILD +++ b/pkg/shim/v2/runtimeoptions/BUILD @@ -11,12 +11,12 @@ proto_library( go_library( name = "runtimeoptions", - srcs = ["runtimeoptions.go"], - visibility = ["//pkg/shim/v2:__pkg__"], - deps = [ - ":api_go_proto", - "@com_github_gogo_protobuf//proto:go_default_library", + srcs = [ + "runtimeoptions.go", + "runtimeoptions_cri.go", ], + visibility = ["//pkg/shim/v2:__pkg__"], + deps = ["@com_github_gogo_protobuf//proto:go_default_library"], ) go_test( @@ -27,6 +27,6 @@ go_test( deps = [ "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", "@com_github_containerd_typeurl//:go_default_library", - "@com_github_golang_protobuf//proto:go_default_library", + "@com_github_gogo_protobuf//proto:go_default_library", ], ) diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go index aaf17b87a..072dd87f0 100644 --- a/pkg/shim/v2/runtimeoptions/runtimeoptions.go +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go @@ -13,18 +13,5 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package runtimeoptions contains the runtimeoptions proto. package runtimeoptions - -import ( - proto "github.com/gogo/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto" -) - -type Options = pb.Options - -func init() { - // The generated proto file auto registers with "golang/protobuf/proto" - // package. However, typeurl uses "golang/gogo/protobuf/proto". So registers - // the type there too. - proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options") -} diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go b/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go new file mode 100644 index 000000000..e6102b4cf --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go @@ -0,0 +1,383 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtimeoptions + +import ( + "fmt" + "io" + "reflect" + "strings" + + proto "github.com/gogo/protobuf/proto" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +type Options struct { + // TypeUrl specifies the type of the content inside the config file. + TypeUrl string `protobuf:"bytes,1,opt,name=type_url,json=typeUrl,proto3" json:"type_url,omitempty"` + // ConfigPath specifies the filesystem location of the config file + // used by the runtime. + ConfigPath string `protobuf:"bytes,2,opt,name=config_path,json=configPath,proto3" json:"config_path,omitempty"` +} + +func (m *Options) Reset() { *m = Options{} } +func (*Options) ProtoMessage() {} +func (*Options) Descriptor() ([]byte, []int) { return fileDescriptorApi, []int{0} } + +func (m *Options) GetTypeUrl() string { + if m != nil { + return m.TypeUrl + } + return "" +} + +func (m *Options) GetConfigPath() string { + if m != nil { + return m.ConfigPath + } + return "" +} + +func init() { + proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options") +} + +func (m *Options) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Options) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.TypeUrl) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintApi(dAtA, i, uint64(len(m.TypeUrl))) + i += copy(dAtA[i:], m.TypeUrl) + } + if len(m.ConfigPath) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintApi(dAtA, i, uint64(len(m.ConfigPath))) + i += copy(dAtA[i:], m.ConfigPath) + } + return i, nil +} + +func encodeVarintApi(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} + +func (m *Options) Size() (n int) { + var l int + _ = l + l = len(m.TypeUrl) + if l > 0 { + n += 1 + l + sovApi(uint64(l)) + } + l = len(m.ConfigPath) + if l > 0 { + n += 1 + l + sovApi(uint64(l)) + } + return n +} + +func sovApi(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} + +func sozApi(x uint64) (n int) { + return sovApi(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} + +func (this *Options) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&Options{`, + `TypeUrl:` + fmt.Sprintf("%v", this.TypeUrl) + `,`, + `ConfigPath:` + fmt.Sprintf("%v", this.ConfigPath) + `,`, + `}`, + }, "") + return s +} + +func valueToStringApi(v interface{}) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("*%v", pv) +} + +func (m *Options) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Options: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Options: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field TypeUrl", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthApi + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.TypeUrl = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ConfigPath", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthApi + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ConfigPath = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipApi(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthApi + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +func skipApi(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthApi + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipApi(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthApi = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowApi = fmt.Errorf("proto: integer overflow") +) + +func init() { proto.RegisterFile("api.proto", fileDescriptorApi) } + +var fileDescriptorApi = []byte{ + // 183 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4c, 0x2c, 0xc8, 0xd4, + 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x4d, 0x2e, 0xca, 0xd4, 0x2b, 0x2a, 0xcd, 0x2b, 0xc9, + 0xcc, 0x4d, 0xcd, 0x2f, 0x28, 0xc9, 0xcc, 0xcf, 0x2b, 0xd6, 0x2b, 0x33, 0x94, 0xd2, 0x4d, 0xcf, + 0x2c, 0xc9, 0x28, 0x4d, 0xd2, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xcf, 0x4f, 0xcf, 0xd7, 0x07, 0xab, + 0x4e, 0x2a, 0x4d, 0x03, 0xf3, 0xc0, 0x1c, 0x30, 0x0b, 0x62, 0x8a, 0x92, 0x2b, 0x17, 0xbb, 0x3f, + 0x44, 0xb3, 0x90, 0x24, 0x17, 0x47, 0x49, 0x65, 0x41, 0x6a, 0x7c, 0x69, 0x51, 0x8e, 0x04, 0xa3, + 0x02, 0xa3, 0x06, 0x67, 0x10, 0x3b, 0x88, 0x1f, 0x5a, 0x94, 0x23, 0x24, 0xcf, 0xc5, 0x9d, 0x9c, + 0x9f, 0x97, 0x96, 0x99, 0x1e, 0x5f, 0x90, 0x58, 0x92, 0x21, 0xc1, 0x04, 0x96, 0xe5, 0x82, 0x08, + 0x05, 0x24, 0x96, 0x64, 0x38, 0xc9, 0x9c, 0x78, 0x28, 0xc7, 0x78, 0xe3, 0xa1, 0x1c, 0x43, 0xc3, + 0x23, 0x39, 0xc6, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, 0x48, 0x8e, 0x71, + 0xc2, 0x63, 0x39, 0x86, 0x24, 0x36, 0xb0, 0x5d, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x07, + 0x00, 0xf2, 0x18, 0xbe, 0x00, 0x00, 0x00, +} diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go index f4c238a00..c59a2400e 100644 --- a/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go @@ -15,11 +15,12 @@ package runtimeoptions import ( + "bytes" "testing" shim "github.com/containerd/containerd/runtime/v1/shim/v1" "github.com/containerd/typeurl" - "github.com/golang/protobuf/proto" + "github.com/gogo/protobuf/proto" ) func TestCreateTaskRequest(t *testing.T) { @@ -32,7 +33,11 @@ func TestCreateTaskRequest(t *testing.T) { if err := proto.UnmarshalText(encodedText, got); err != nil { t.Fatalf("unable to unmarshal text: %v", err) } - t.Logf("got: %s", proto.MarshalTextString(got)) + var textBuffer bytes.Buffer + if err := proto.MarshalText(&textBuffer, got); err != nil { + t.Errorf("unable to marshal text: %v", err) + } + t.Logf("got: %s", string(textBuffer.Bytes())) // Check the options. wantOptions := &Options{} diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 089b3bbef..92c51879b 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "pending_list", - out = "pending_list.go", - package = "state", - prefix = "pending", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*objectEncodeState", - "ElementMapper": "pendingMapper", - "Linker": "*pendingEntry", - }, -) - -go_template_instance( name = "deferred_list", out = "deferred_list.go", package = "state", @@ -83,7 +70,6 @@ go_library( "deferred_list.go", "encode.go", "encode_unsafe.go", - "pending_list.go", "state.go", "state_norace.go", "state_race.go", diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 89467ca8e..e519ddeca 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -21,6 +21,7 @@ import ( "math" "reflect" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/state/wire" ) @@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c // For the purposes of this function, a child object is either a field within a // struct or an array element, with one such indirection per element in // path. The returned value may be an unexported field, so it may not be -// directly assignable. See unsafePointerTo. +// directly assignable. See decode_unsafe.go. func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { // See wire.Ref.Dots. The path here is specified in reverse order. for i := len(path) - 1; i >= 0; i-- { @@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // Normal assignment: authoritative only if no dots. v := ds.register(x, obj.Type().Elem()) - if v.IsValid() { - obj.Set(unsafePointerTo(v)) - } + obj.Set(reflectValueRWAddr(v)) case wire.Bool: obj.SetBool(bool(x)) case wire.Int: @@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // contents will still be filled in later on. typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. v := ds.register(&x.Ref, typ) - obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity))) case *wire.Array: ds.decodeArray(ods, obj, x) case *wire.Struct: @@ -592,7 +591,7 @@ func (ds *decodeState) Load(obj reflect.Value) { ds.pending.PushBack(rootOds) // Read the number of objects. - lastID, object, err := ReadHeader(ds.r) + numObjects, object, err := ReadHeader(ds.r) if err != nil { Failf("header error: %w", err) } @@ -604,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) { var ( encoded wire.Object ods *objectDecodeState - id = objectID(1) + id objectID tid = typeID(1) ) if err := safely(func() { // Decode all objects in the stream. // - // Note that the structure of this decoding loop should match - // the raw decoding loop in printer.go. - for id <= objectID(lastID) { - // Unmarshal the object. + // Note that the structure of this decoding loop should match the raw + // decoding loop in state/pretty/pretty.printer.printStream(). + for i := uint64(0); i < numObjects; { + // Unmarshal either a type object or object ID. encoded = wire.Load(ds.r) - - // Is this a type object? Handle inline. - if wt, ok := encoded.(*wire.Type); ok { - ds.types.Register(wt) + switch we := encoded.(type) { + case *wire.Type: + ds.types.Register(we) tid++ encoded = nil continue + case wire.Uint: + id = objectID(we) + i++ + // Unmarshal and resolve the actual object. + encoded = wire.Load(ds.r) + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we deferred + // decoding until interest is registered. + ds.deferred[id] = encoded + } + // For error handling. + ods = nil + encoded = nil + default: + Failf("wanted type or object ID, got %#v", encoded) } - - // Actually resolve the object. - ods = ds.lookup(id) - if ods != nil { - // Decode the object. - ds.decodeObject(ods, ods.obj, encoded) - } else { - // If an object hasn't had interest registered - // previously or isn't yet valid, we deferred - // decoding until interest is registered. - ds.deferred[id] = encoded - } - - // For error handling. - ods = nil - encoded = nil - id++ } }); err != nil { // Include as much information as we can, taking into account @@ -647,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) { if ods != nil { Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) } else if encoded != nil { - Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) + Failf("error decoding from %#v: %w", encoded, err) } else { Failf("general decoding error: %w", err) } } // Check if we have any deferred objects. + numDeferred := 0 for id, encoded := range ds.deferred { - // Shoud never happen, the graph was bogus. - Failf("still have deferred objects: one is ID %d, %#v", id, encoded) + numDeferred++ + if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 { + typ := ds.types.LookupType(typeID(s.TypeID)) + log.Warningf("unused deferred object: ID %d, type %v", id, typ) + } else { + log.Warningf("unused deferred object: ID %d, %#v", id, encoded) + } + } + if numDeferred != 0 { + Failf("still had %d deferred objects", numDeferred) } // Scan and fire all callbacks. We iterate over the list of incomplete diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go index d048f61a1..f1208e2a2 100644 --- a/pkg/state/decode_unsafe.go +++ b/pkg/state/decode_unsafe.go @@ -15,13 +15,62 @@ package state import ( + "fmt" "reflect" + "runtime" "unsafe" ) -// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on -// values representing unexported fields. This bypasses visibility, but not -// type safety. -func unsafePointerTo(obj reflect.Value) reflect.Value { +// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned +// reflect.Value is usable in assignments even if obj was obtained by the use +// of unexported struct fields. +// +// Preconditions: obj.CanAddr(). +func reflectValueRWAddr(obj reflect.Value) reflect.Value { return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr())) } + +// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the +// returned reflect.Value is usable in assignments even if obj was obtained by +// the use of unexported struct fields. +// +// Preconditions: +// * arr.Kind() == reflect.Array. +// * i, j, k >= 0. +// * i <= j <= k <= arr.Len(). +func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value { + if arr.Kind() != reflect.Array { + panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array)) + } + if i < 0 || j < 0 || k < 0 { + panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k)) + } + if i > j { + panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j)) + } + if j > k { + panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k)) + } + if k > arr.Len() { + panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len())) + } + + sliceTyp := reflect.SliceOf(arr.Type().Elem()) + if i == arr.Len() { + // By precondition, i == j == k == arr.Len(). + return reflect.MakeSlice(sliceTyp, 0, 0) + } + slh := reflect.SliceHeader{ + // reflect.Value.CanAddr() == false for arrays, so we need to get the + // address from the first element of the array. + Data: arr.Index(i).UnsafeAddr(), + Len: j - i, + Cap: k - i, + } + slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem() + // Before slobj is constructed, arr holds the only pointer-typed pointer to + // the array since reflect.SliceHeader.Data is a uintptr, so arr must be + // kept alive. + runtime.KeepAlive(arr) + return slobj +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 92fcad4e9..560e7c2a3 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -17,13 +17,14 @@ package state import ( "context" "reflect" + "sort" "gvisor.dev/gvisor/pkg/state/wire" ) // objectEncodeState the type and identity of an object occupying a memory // address range. This is the value type for addrSet, and the intrusive entry -// for the pending and deferred lists. +// for the deferred list. type objectEncodeState struct { // id is the assigned ID for this object. id objectID @@ -47,7 +48,6 @@ type objectEncodeState struct { // references may be updated directly and automatically. refs []*wire.Ref - pendingEntry deferredEntry } @@ -93,9 +93,15 @@ type encodeState struct { // serialized. pendingTypes []wire.Type - // pending is the list of objects to be serialized. Serialization does + // pending maps object IDs to objects to be serialized. Serialization does // not actually occur until the full object graph is computed. - pending pendingList + pending map[objectID]*objectEncodeState + + // encodedStructs maps reflect.Values representing structs to previous + // encodings of those structs. This is necessary to avoid duplicate calls + // to SaverLoader.StateSave() that may result in multiple calls to + // Sink.SaveValue() for a given field, resulting in object duplication. + encodedStructs map[reflect.Value]*wire.Struct // stats tracks time data. stats Stats @@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { // depending on this value knows there's nothing there. return } - if seg, _ := es.values.Find(addr); seg.Ok() { + seg, gap := es.values.Find(addr) + if seg.Ok() { // Ensure the map types match. existing := seg.Value() if existing.obj.Type() != obj.Type() { @@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { } // Record the map. + r := addrRange{addr, addr + 1} oes := &objectEncodeState{ id: es.nextID(), obj: obj, how: encodeMapAsValue, } - es.values.Add(addrRange{addr, addr + 1}, oes) - es.pending.PushBack(oes) + // Use Insert instead of InsertWithoutMergingUnchecked when race + // detection is enabled to get additional sanity-checking from Merge. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes es.deferred.PushBack(oes) // See above: no ref recording. @@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { obj: obj, } es.zeroValues[typ] = oes - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) } @@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { size = 1 // See above. } - // Calculate the container. end := addr + size r := addrRange{addr, end} - if seg, _ := es.values.Find(addr); seg.Ok() { + seg := es.values.LowerBoundSegment(addr) + var ( + oes *objectEncodeState + gap addrGapIterator + ) + + // Does at least one previously-registered object overlap this one? + if seg.Ok() && seg.Start() < end { existing := seg.Value() - switch { - case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type(): - // The object is a perfect match. Happy path. Avoid the - // traversal and just return directly. We don't need to - // encode the type information or any dots here. + + if seg.Range() == r && typ == existing.obj.Type() { + // This exact object is already registered. Avoid the traversal and + // just return directly. We don't need to encode the type + // information or any dots here. ref.Root = wire.Uint(existing.id) existing.refs = append(existing.refs, ref) return + } - case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end): - // The previously registered object is larger than - // this, no need to update. But we expect some - // traversal below. + if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) { + // This object is contained within a previously-registered object. + // Perform traversal from the container to the new object. + ref.Root = wire.Uint(existing.id) + ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr) + ref.Type = es.findType(existing.obj.Type()) + existing.refs = append(existing.refs, ref) + return + } - case seg.Start() == addr && seg.End() == end: - if !isSameSizeParent(obj, existing.obj.Type()) { - break // Needs traversal. + // This object contains one or more previously-registered objects. + // Remove them and update existing references to use the new one. + oes := &objectEncodeState{ + // Reuse the root ID of the first contained element. + id: existing.id, + obj: obj, + } + type elementEncodeState struct { + addr uintptr + typ reflect.Type + refs []*wire.Ref + } + var ( + elems []elementEncodeState + gap addrGapIterator + ) + for { + // Each contained object should be completely contained within + // this one. + if raceEnabled && !r.IsSupersetOf(seg.Range()) { + Failf("containing object %#v does not contain existing object %#v", obj, existing.obj) } - fallthrough // Needs update. - - case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end): - // Update the object and redo the encoding. - old := existing.obj - existing.obj = obj + elems = append(elems, elementEncodeState{ + addr: seg.Start(), + typ: existing.obj.Type(), + refs: existing.refs, + }) + delete(es.pending, existing.id) es.deferred.Remove(existing) - es.deferred.PushBack(existing) - - // The previously registered object is superseded by - // this new object. We are guaranteed to not have any - // mergeable neighbours in this segment set. - if !raceEnabled { - seg.SetRangeUnchecked(r) - } else { - // Add extra paranoid. This will be statically - // removed at compile time unless a race build. - es.values.Remove(seg) - es.values.Add(r, existing) - seg = es.values.LowerBoundSegment(addr) + gap = es.values.Remove(seg) + seg = gap.NextSegment() + if !seg.Ok() || seg.Start() >= end { + break } - - // Compute the traversal required & update references. - dots := traverse(obj.Type(), old.Type(), addr, seg.Start()) - wt := es.findType(obj.Type()) - for _, ref := range existing.refs { + existing = seg.Value() + } + wt := es.findType(typ) + for _, elem := range elems { + dots := traverse(typ, elem.typ, addr, elem.addr) + for _, ref := range elem.refs { + ref.Root = wire.Uint(oes.id) ref.Dots = append(ref.Dots, dots...) ref.Type = wt } - default: - // There is a non-sensical overlap. - Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj) + oes.refs = append(oes.refs, elem.refs...) } - - // Compute the new reference, record and return it. - ref.Root = wire.Uint(existing.id) - ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr) - ref.Type = es.findType(obj.Type()) - existing.refs = append(existing.refs, ref) + // Finally register the new containing object. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes + es.deferred.PushBack(oes) + ref.Root = wire.Uint(oes.id) + oes.refs = append(oes.refs, ref) return } - // The only remaining case is a pointer value that doesn't overlap with - // any registered addresses. Create a new entry for it, and start - // tracking the first reference we just created. - oes := &objectEncodeState{ + // No existing object overlaps this one. Register a new object. + oes = &objectEncodeState{ id: es.nextID(), obj: obj, } + if seg.Ok() { + gap = seg.PrevGap() + } else { + gap = es.values.LastGap() + } if !raceEnabled { - es.values.AddWithoutMerging(r, oes) + es.values.InsertWithoutMergingUnchecked(gap, r, oes) } else { - // Merges should never happen. This is just enabled extra - // sanity checks because the Merge function below will panic. - es.values.Add(r, oes) + es.values.Insert(gap, r, oes) } - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) ref.Root = wire.Uint(oes.id) oes.refs = append(oes.refs, ref) @@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) { // encodeStruct encodes a composite object. func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { + if s, ok := es.encodedStructs[obj]; ok { + *dest = s + return + } + s := &wire.Struct{} + *dest = s + es.encodedStructs[obj] = s + // Ensure that the obj is addressable. There are two cases when it is // not. First, is when this is dispatched via SaveValue. Second, when // this is a map key as a struct. Either way, we need to make a copy to @@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { obj = localObj.Elem() } - // Prepare the value. - s := &wire.Struct{} - *dest = s - // Look the type up in the database. te, ok := es.types.Lookup(obj.Type()) if te == nil { @@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) { Failf("encoding error at object %#v: %w", oes.obj.Interface(), err) } - // Check that items are pending. - if es.pending.Front() == nil { + // Check that we have objects to serialize. + if len(es.pending) == 0 { Failf("pending is empty?") } - // Write the header with the number of objects. Note that there is no - // way that es.lastID could conflict with objectID, which would - // indicate that an impossibly large encoding. - if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil { + // Write the header with the number of objects. + if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil { Failf("error writing header: %w", err) } // Serialize all pending types and pending objects. Note that we don't // bother removing from this list as we walk it because that just // wastes time. It will not change after this point. - var id objectID if err := safely(func() { for _, wt := range es.pendingTypes { // Encode the type. wire.Save(es.w, &wt) } - for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() { - id++ // First object is 1. - if oes.id != id { - Failf("expected id %d, got %d", id, oes.id) - } - - // Marshall the object. + // Emit objects in ID order. + ids := make([]objectID, 0, len(es.pending)) + for id := range es.pending { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { + return ids[i] < ids[j] + }) + for _, id := range ids { + // Encode the id. + wire.Save(es.w, wire.Uint(id)) + // Marshal the object. + oes := es.pending[id] wire.Save(es.w, oes.encoded) } }); err != nil { // Include the object and the error. Failf("error serializing object %#v: %w", oes.encoded, err) } - - // Check what we wrote. - if id != es.lastID { - Failf("expected %d objects, wrote %d", es.lastID, id) - } } // objectFlag indicates that the length is a # of objects, rather than a raw @@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error { }) } -// pendingMapper is for the pending list. -type pendingMapper struct{} - -func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry } - // deferredMapper is for the deferred list. type deferredMapper struct{} diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go index 887f453a9..c6e8bb31d 100644 --- a/pkg/state/pretty/pretty.go +++ b/pkg/state/pretty/pretty.go @@ -42,6 +42,7 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { buf.WriteString(typ) buf.WriteString(")(") buf.WriteString(baseRef) + buf.WriteString(")") for _, component := range x.Dots { switch v := component.(type) { case *wire.FieldName: @@ -53,7 +54,6 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component))) } } - buf.WriteString(")") fullRef = buf.String() } if p.html { @@ -242,19 +242,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { // Note that this loop must match the general structure of the // loop in decode.go. But we don't register type information, // etc. and just print the raw structures. + type objectAndID struct { + id uint64 + obj wire.Object + } var ( tid uint64 = 1 - objects []wire.Object + objects []objectAndID ) - for oid := uint64(1); oid <= length; { - // Unmarshal the object. + for i := uint64(0); i < length; { + // Unmarshal either a type object or object ID. encoded := wire.Load(r) - - // Is this a type? - if typ, ok := encoded.(*wire.Type); ok { + switch we := encoded.(type) { + case *wire.Type: str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dt%d", graph, tid) - p.typeSpecs[tag] = typ + p.typeSpecs[tag] = we if p.html { // See below. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) @@ -263,20 +266,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { return err } tid++ - continue + case wire.Uint: + // Unmarshal the actual object. + objects = append(objects, objectAndID{ + id: uint64(we), + obj: wire.Load(r), + }) + i++ + default: + return fmt.Errorf("wanted type or object ID, got %#v", encoded) } - - // Otherwise, it is a node. - objects = append(objects, encoded) - oid++ } - for i, encoded := range objects { - // oid starts at 1. - oid := i + 1 + for _, objAndID := range objects { // Format the node. - str, _ := p.format(graph, 0, encoded) - tag := fmt.Sprintf("g%dr%d", graph, oid) + str, _ := p.format(graph, 0, objAndID.obj) + tag := fmt.Sprintf("g%dr%d", graph, objAndID.id) if p.html { // Create a little tag with an anchor next to it for linking. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) diff --git a/pkg/state/state.go b/pkg/state/state.go index acb629969..6b8540f03 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error { func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) { // Create the encoding state. es := encodeState{ - ctx: ctx, - w: w, - types: makeTypeEncodeDatabase(), - zeroValues: make(map[reflect.Type]*objectEncodeState), + ctx: ctx, + w: w, + types: makeTypeEncodeDatabase(), + zeroValues: make(map[reflect.Type]*objectEncodeState), + pending: make(map[objectID]*objectEncodeState), + encodedStructs: make(map[reflect.Value]*wire.Struct), } // Perform the encoding. diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go index bd2c2b399..69143d194 100644 --- a/pkg/state/tests/struct.go +++ b/pkg/state/tests/struct.go @@ -54,12 +54,47 @@ type outerArray struct { } // +stateify savable +type outerSlice struct { + inner []inner +} + +// +stateify savable type inner struct { v int64 } // +stateify savable +type outerFieldValue struct { + inner innerFieldValue +} + +// +stateify savable +type innerFieldValue struct { + v int64 `state:".(*savedFieldValue)"` +} + +// +stateify savable +type savedFieldValue struct { + v int64 +} + +func (ifv *innerFieldValue) saveV() *savedFieldValue { + return &savedFieldValue{ifv.v} +} + +func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) { + ifv.v = sfv.v +} + +// +stateify savable type system struct { v1 interface{} v2 interface{} } + +// +stateify savable +type system3 struct { + v1 interface{} + v2 interface{} + v3 interface{} +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go index de9d17aa7..c91c2c032 100644 --- a/pkg/state/tests/struct_test.go +++ b/pkg/state/tests/struct_test.go @@ -15,6 +15,7 @@ package tests import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/state" @@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) { } func TestEmbeddedPointers(t *testing.T) { - var ( - ofs outerSame - of1 outerFieldFirst - of2 outerFieldSecond - oa outerArray - ) + // Give each int64 a random value to prevent Go from using + // runtime.staticuint64s, which confounds tests for struct duplication. + magic := func() int64 { + for { + n := rand.Int63() + if n < 0 || n > 255 { + return n + } + } + } + + ofs := outerSame{inner{magic()}} + of1 := outerFieldFirst{inner{magic()}, magic()} + of2 := outerFieldSecond{magic(), inner{magic()}} + oa := outerArray{[2]inner{{magic()}, {magic()}}} + osl := outerSlice{oa.inner[:]} + ofv := outerFieldValue{innerFieldValue{magic()}} runTestCases(t, false, "embedded-pointers", []interface{}{ system{&ofs, &ofs.inner}, @@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) { system{&oa, &oa.inner[1]}, system{&oa.inner[0], &oa}, system{&oa.inner[1], &oa}, + system3{&oa, &oa.inner[0], &oa.inner[1]}, + system3{&oa, &oa.inner[1], &oa.inner[0]}, + system3{&oa.inner[0], &oa, &oa.inner[1]}, + system3{&oa.inner[1], &oa, &oa.inner[0]}, + system3{&oa.inner[0], &oa.inner[1], &oa}, + system3{&oa.inner[1], &oa.inner[0], &oa}, + system{&oa, &osl}, + system{&osl, &oa}, + system{&ofv, &ofv.inner}, + system{&ofv.inner, &ofv}, }) } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 12b061def..b196324c7 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -97,6 +97,9 @@ type testConnection struct { func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + return nil, err + } entry, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&entry, waiter.EventOut) @@ -145,7 +148,9 @@ func TestCloseReader(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } // Give c.Read() a chance to block before closing the connection. @@ -416,7 +421,9 @@ func TestDeadlineChange(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } c.SetDeadline(time.Now().Add(time.Minute)) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index d4d785cca..530f2ae2f 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -178,6 +178,24 @@ func PayloadLen(payloadLength int) NetworkChecker { } } +// IPPayload creates a checker that checks the payload. +func IPPayload(payload []byte) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + got := h[0].Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(got) == 0 && len(payload) == 0 { + return + } + + if diff := cmp.Diff(payload, got); diff != "" { + t.Errorf("payload mismatch (-want +got):\n%s", diff) + } + } +} + // IPv4Options returns a checker that checks the options in an IPv4 packet. func IPv4Options(want []byte) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -187,7 +205,7 @@ func IPv4Options(want []byte) NetworkChecker { if !ok { t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) } - options := ip.Options() + options := []byte(ip.Options()) // cmp.Diff does not consider nil slices equal to empty slices, but we do. if len(want) == 0 && len(options) == 0 { return @@ -841,6 +859,21 @@ func ICMPv4Seq(want uint16) TransportChecker { } } +// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. +func ICMPv4Pointer(want uint8) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Pointer(); got != want { + t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) + } + } +} + // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. // This assumes that the payload exactly makes up the rest of the slice. func ICMPv4Checksum() TransportChecker { @@ -935,6 +968,38 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { } } +// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific +// field. +func ICMPv6TypeSpecific(want uint32) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + if got := icmpv6.TypeSpecific(); got != want { + t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) + } + } +} + +// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. +func ICMPv6Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + payload := icmpv6.Payload() + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) + } + } +} + // NDP creates a checker that checks that the packet contains a valid NDP // message for type of ty, with potentially additional checks specified by // checkers. diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 504408878..2f13dea6a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -99,7 +99,8 @@ const ( // ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. const ( - ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4ReassemblyTimeout ICMPv4Code = 1 ) // ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. @@ -126,6 +127,12 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } +// Pointer returns the pointer field in a Parameter Problem packet. +func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] } + +// SetPointer sets the pointer field in a Parameter Problem packet. +func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } + // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 4c6e4be64..961b77628 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,7 +39,6 @@ import ( // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Options | Padding | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// const ( versIHL = 0 tos = 1 @@ -93,7 +93,7 @@ type IPv4Fields struct { DstAddr tcpip.Address } -// IPv4 represents an ipv4 header stored in a byte array. +// IPv4 is an IPv4 header. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. // Always call IsValid() to validate an instance of IPv4 before using other @@ -106,10 +106,13 @@ const ( IPv4MinimumSize = 20 // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given - // that there are only 4 bits to represents the header length in 32-bit - // units, the header cannot exceed 15*4 = 60 bytes. + // that there are only 4 bits (max 0xF (15)) to represent the header length + // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // IPv4MaximumOptionsSize is the largest size the IPv4 options can be. + IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. // // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 @@ -130,7 +133,7 @@ const ( // IPv4ProtocolNumber is IPv4's network protocol number. IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800 - // IPv4Version is the version of the ipv4 protocol. + // IPv4Version is the version of the IPv4 protocol. IPv4Version = 4 // IPv4AllSystems is the all systems IPv4 multicast address as per @@ -148,6 +151,13 @@ const ( // packet that every IPv4 capable host must be able to // process/reassemble. IPv4MinimumProcessableDatagramSize = 576 + + // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791, + // section 3.2: + // Every internet module must be able to forward a datagram of 68 octets + // without further fragmentation. This is because an internet header may be + // up to 60 octets, and the minimum fragment is 8 octets. + IPv4MinimumMTU = 68 ) // Flags that may be set in an IPv4 packet. @@ -191,14 +201,13 @@ func IPVersion(b []byte) int { // Internet Header Length is the length of the internet header in 32 // bit words, and thus points to the beginning of the data. Note that // the minimum value for a correct header is 5. -// const ( ipVersionShift = 4 ipIHLMask = 0x0f IPv4IHLStride = 4 ) -// HeaderLength returns the value of the "header length" field of the ipv4 +// HeaderLength returns the value of the "header length" field of the IPv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { return (b[versIHL] & ipIHLMask) * IPv4IHLStride @@ -212,17 +221,17 @@ func (b IPv4) SetHeaderLength(hdrLen uint8) { b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } -// ID returns the value of the identifier field of the ipv4 header. +// ID returns the value of the identifier field of the IPv4 header. func (b IPv4) ID() uint16 { return binary.BigEndian.Uint16(b[id:]) } -// Protocol returns the value of the protocol field of the ipv4 header. +// Protocol returns the value of the protocol field of the IPv4 header. func (b IPv4) Protocol() uint8 { return b[protocol] } -// Flags returns the "flags" field of the ipv4 header. +// Flags returns the "flags" field of the IPv4 header. func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } @@ -232,41 +241,44 @@ func (b IPv4) More() bool { return b.Flags()&IPv4FlagMoreFragments != 0 } -// TTL returns the "TTL" field of the ipv4 header. +// TTL returns the "TTL" field of the IPv4 header. func (b IPv4) TTL() uint8 { return b[ttl] } -// FragmentOffset returns the "fragment offset" field of the ipv4 header. +// FragmentOffset returns the "fragment offset" field of the IPv4 header. func (b IPv4) FragmentOffset() uint16 { return binary.BigEndian.Uint16(b[flagsFO:]) << 3 } -// TotalLength returns the "total length" field of the ipv4 header. +// TotalLength returns the "total length" field of the IPv4 header. func (b IPv4) TotalLength() uint16 { return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } -// Checksum returns the checksum field of the ipv4 header. +// Checksum returns the checksum field of the IPv4 header. func (b IPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[checksum:]) } -// SourceAddress returns the "source address" field of the ipv4 header. +// SourceAddress returns the "source address" field of the IPv4 header. func (b IPv4) SourceAddress() tcpip.Address { return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize]) } -// DestinationAddress returns the "destination address" field of the ipv4 +// DestinationAddress returns the "destination address" field of the IPv4 // header. func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } -// Options returns a a buffer holding the options. -func (b IPv4) Options() []byte { +// IPv4Options is a buffer that holds all the raw IP options. +type IPv4Options []byte + +// Options returns a buffer holding the options. +func (b IPv4) Options() IPv4Options { hdrLen := b.HeaderLength() - return b[options:hdrLen:hdrLen] + return IPv4Options(b[options:hdrLen:hdrLen]) } // TransportProtocol implements Network.TransportProtocol. @@ -279,17 +291,17 @@ func (b IPv4) Payload() []byte { return b[b.HeaderLength():][:b.PayloadLength()] } -// PayloadLength returns the length of the payload portion of the ipv4 packet. +// PayloadLength returns the length of the payload portion of the IPv4 packet. func (b IPv4) PayloadLength() uint16 { return b.TotalLength() - uint16(b.HeaderLength()) } -// TOS returns the "type of service" field of the ipv4 header. +// TOS returns the "type of service" field of the IPv4 header. func (b IPv4) TOS() (uint8, uint32) { return b[tos], 0 } -// SetTOS sets the "type of service" field of the ipv4 header. +// SetTOS sets the "type of service" field of the IPv4 header. func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } @@ -299,18 +311,18 @@ func (b IPv4) SetTTL(v byte) { b[ttl] = v } -// SetTotalLength sets the "total length" field of the ipv4 header. +// SetTotalLength sets the "total length" field of the IPv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } -// SetChecksum sets the checksum field of the ipv4 header. +// SetChecksum sets the checksum field of the IPv4 header. func (b IPv4) SetChecksum(v uint16) { binary.BigEndian.PutUint16(b[checksum:], v) } // SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the -// ipv4 header. +// IPv4 header. func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { v := (uint16(flags) << 13) | (offset >> 3) binary.BigEndian.PutUint16(b[flagsFO:], v) @@ -321,23 +333,23 @@ func (b IPv4) SetID(v uint16) { binary.BigEndian.PutUint16(b[id:], v) } -// SetSourceAddress sets the "source address" field of the ipv4 header. +// SetSourceAddress sets the "source address" field of the IPv4 header. func (b IPv4) SetSourceAddress(addr tcpip.Address) { copy(b[srcAddr:srcAddr+IPv4AddressSize], addr) } -// SetDestinationAddress sets the "destination address" field of the ipv4 +// SetDestinationAddress sets the "destination address" field of the IPv4 // header. func (b IPv4) SetDestinationAddress(addr tcpip.Address) { copy(b[dstAddr:dstAddr+IPv4AddressSize], addr) } -// CalculateChecksum calculates the checksum of the ipv4 header. +// CalculateChecksum calculates the checksum of the IPv4 header. func (b IPv4) CalculateChecksum() uint16 { return Checksum(b[:b.HeaderLength()], 0) } -// Encode encodes all the fields of the ipv4 header. +// Encode encodes all the fields of the IPv4 header. func (b IPv4) Encode(i *IPv4Fields) { b.SetHeaderLength(i.IHL) b[tos] = i.TOS @@ -351,7 +363,7 @@ func (b IPv4) Encode(i *IPv4Fields) { copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) } -// EncodePartial updates the total length and checksum fields of ipv4 header, +// EncodePartial updates the total length and checksum fields of IPv4 header, // taking in the partial checksum, which is the checksum of the header without // the total length and checksum fields. It is useful in cases when similar // packets are produced. @@ -398,3 +410,424 @@ func IsV4LoopbackAddress(addr tcpip.Address) bool { } return addr[0] == 0x7f } + +// ========================= Options ========================== + +// An IPv4OptionType can hold the valuse for the Type in an IPv4 option. +type IPv4OptionType byte + +// These constants are needed to identify individual options in the option list. +// While RFC 791 (page 31) says "Every internet module must be able to act on +// every option." This has not generally been adhered to and some options have +// very low rates of support. We do not support options other than those shown +// below. + +const ( + // IPv4OptionListEndType is the option type for the End Of Option List + // option. Anything following is ignored. + IPv4OptionListEndType IPv4OptionType = 0 + + // IPv4OptionNOPType is the No-Operation option. May appear between other + // options and may appear multiple times. + IPv4OptionNOPType IPv4OptionType = 1 + + // IPv4OptionRecordRouteType is used by each router on the path of the packet + // to record its path. It is carried over to an Echo Reply. + IPv4OptionRecordRouteType IPv4OptionType = 7 + + // IPv4OptionTimestampType is the option type for the Timestamp option. + IPv4OptionTimestampType IPv4OptionType = 68 + + // ipv4OptionTypeOffset is the offset in an option of its type field. + ipv4OptionTypeOffset = 0 + + // IPv4OptionLengthOffset is the offset in an option of its length field. + IPv4OptionLengthOffset = 1 +) + +// Potential errors when parsing generic IP options. +var ( + ErrIPv4OptZeroLength = errors.New("zero length IP option") + ErrIPv4OptDuplicate = errors.New("duplicate IP option") + ErrIPv4OptInvalid = errors.New("invalid IP option") + ErrIPv4OptMalformed = errors.New("malformed IP option") + ErrIPv4OptionTruncated = errors.New("truncated IP option") + ErrIPv4OptionAddress = errors.New("bad IP option address") +) + +// IPv4Option is an interface representing various option types. +type IPv4Option interface { + // Type returns the type identifier of the option. + Type() IPv4OptionType + + // Size returns the size of the option in bytes. + Size() uint8 + + // Contents returns a slice holding the contents of the option. + Contents() []byte +} + +var _ IPv4Option = (*IPv4OptionGeneric)(nil) + +// IPv4OptionGeneric is an IPv4 Option of unknown type. +type IPv4OptionGeneric []byte + +// Type implements IPv4Option. +func (o *IPv4OptionGeneric) Type() IPv4OptionType { + return IPv4OptionType((*o)[ipv4OptionTypeOffset]) +} + +// Size implements IPv4Option. +func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } + +// Contents implements IPv4Option. +func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) } + +// IPv4OptionIterator is an iterator pointing to a specific IP option +// at any point of time. It also holds information as to a new options buffer +// that we are building up to hand back to the caller. +type IPv4OptionIterator struct { + options IPv4Options + // ErrCursor is where we are while parsing options. It is exported as any + // resulting ICMP packet is supposed to have a pointer to the byte within + // the IP packet where the error was detected. + ErrCursor uint8 + nextErrCursor uint8 + newOptions [IPv4MaximumOptionsSize]byte + writePoint int +} + +// MakeIterator sets up and returns an iterator of options. It also sets up the +// building of a new option set. +func (o IPv4Options) MakeIterator() IPv4OptionIterator { + return IPv4OptionIterator{ + options: o, + nextErrCursor: IPv4MinimumSize, + } +} + +// RemainingBuffer returns the remaining (unused) part of the new option buffer, +// into which a new option may be written. +func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { + return IPv4Options(i.newOptions[i.writePoint:]) +} + +// ConsumeBuffer marks a portion of the new buffer as used. +func (i *IPv4OptionIterator) ConsumeBuffer(size int) { + i.writePoint += size +} + +// PushNOPOrEnd puts one of the single byte options onto the new options. +// Only values 0 or 1 (ListEnd or NOP) are valid input. +func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) { + if val > IPv4OptionNOPType { + panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val)) + } + i.newOptions[i.writePoint] = byte(val) + i.writePoint++ +} + +// Finalize returns the completed replacement options buffer padded +// as needed. +func (i *IPv4OptionIterator) Finalize() IPv4Options { + // RFC 791 page 31 says: + // The options might not end on a 32-bit boundary. The internet header + // must be filled out with octets of zeros. The first of these would + // be interpreted as the end-of-options option, and the remainder as + // internet header padding. + // Since the buffer is already zero filled we just need to step the write + // pointer up to the next multiple of 4. + options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3]) + // Poison the write pointer. + i.writePoint = len(i.newOptions) + return options +} + +// Next returns the next IP option in the buffer/list of IP options. +// It returns +// - A slice of bytes holding the next option or nil if there is error. +// - A boolean which is true if parsing of all the options is complete. +// - An error which is non-nil if an error condition was encountered. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { + // The opts slice gets shorter as we process the options. When we have no + // bytes left we are done. + if len(i.options) == 0 { + return nil, true, nil + } + + i.ErrCursor = i.nextErrCursor + + optType := IPv4OptionType(i.options[ipv4OptionTypeOffset]) + + if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType { + optionBody := i.options[:1] + i.options = i.options[1:] + i.nextErrCursor = i.ErrCursor + 1 + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil + } + + // There are no more single byte options defined. All the rest have a length + // field so we need to sanity check it. + if len(i.options) == 1 { + return nil, true, ErrIPv4OptMalformed + } + + optLen := i.options[IPv4OptionLengthOffset] + + if optLen == 0 { + i.ErrCursor++ + return nil, true, ErrIPv4OptZeroLength + } + + if optLen == 1 { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + + if optLen > uint8(len(i.options)) { + i.ErrCursor++ + return nil, true, ErrIPv4OptionTruncated + } + + optionBody := i.options[:optLen] + i.nextErrCursor = i.ErrCursor + optLen + i.options = i.options[optLen:] + + // Check the length of some option types that we know. + switch optType { + case IPv4OptionTimestampType: + if optLen < IPv4OptionTimestampHdrLength { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + retval := IPv4OptionTimestamp(optionBody) + return &retval, false, nil + + case IPv4OptionRecordRouteType: + if optLen < IPv4OptionRecordRouteHdrLength { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + retval := IPv4OptionRecordRoute(optionBody) + return &retval, false, nil + } + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil +} + +// +// IP Timestamp option - RFC 791 page 22. +// +--------+--------+--------+--------+ +// |01000100| length | pointer|oflw|flg| +// +--------+--------+--------+--------+ +// | internet address | +// +--------+--------+--------+--------+ +// | timestamp | +// +--------+--------+--------+--------+ +// | ... | +// +// Type = 68 +// +// The Option Length is the number of octets in the option counting +// the type, length, pointer, and overflow/flag octets (maximum +// length 40). +// +// The Pointer is the number of octets from the beginning of this +// option to the end of timestamps plus one (i.e., it points to the +// octet beginning the space for next timestamp). The smallest +// legal value is 5. The timestamp area is full when the pointer +// is greater than the length. +// +// The Overflow (oflw) [4 bits] is the number of IP modules that +// cannot register timestamps due to lack of space. +// +// The Flag (flg) [4 bits] values are +// +// 0 -- time stamps only, stored in consecutive 32-bit words, +// +// 1 -- each timestamp is preceded with internet address of the +// registering entity, +// +// 3 -- the internet address fields are prespecified. An IP +// module only registers its timestamp if it matches its own +// address with the next specified internet address. +// +// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UTC. +// +// The Timestamp is a right-justified, 32-bit timestamp in +// milliseconds since midnight UT. If the time is not available in +// milliseconds or cannot be provided with respect to midnight UT +// then any time may be inserted as a timestamp provided the high +// order bit of the timestamp field is set to one to indicate the +// use of a non-standard value. + +// IPv4OptTSFlags sefines the values expected in the Timestamp +// option Flags field. +type IPv4OptTSFlags uint8 + +// +// Timestamp option specific related constants. +const ( + // IPv4OptionTimestampHdrLength is the length of the timestamp option header. + IPv4OptionTimestampHdrLength = 4 + + // IPv4OptionTimestampSize is the size of an IP timestamp. + IPv4OptionTimestampSize = 4 + + // IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address. + IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize + + // IPv4OptionTimestampMaxSize is limited by space for options + IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize + + // IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp + // is present. + IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0 + + // IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and + // IP are present. + IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1 + + // IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that + // predefined IP is present. + IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3 +) + +// ipv4TimestampTime provides the current time as specified in RFC 791. +func ipv4TimestampTime(clock tcpip.Clock) uint32 { + const millisecondsPerDay = 24 * 3600 * 1000 + const nanoPerMilli = 1000000 + return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay) +} + +// IP Timestamp option fields. +const ( + // IPv4OptTSPointerOffset is the offset of the Timestamp pointer field. + IPv4OptTSPointerOffset = 2 + + // IPv4OptTSPointerOffset is the offset of the combined Flag and Overflow + // fields, (each being 4 bits). + IPv4OptTSOFLWAndFLGOffset = 3 + // These constants define the sub byte fields of the Flag and OverFlow field. + ipv4OptionTimestampOverflowshift = 4 + ipv4OptionTimestampFlagsMask byte = 0x0f +) + +var _ IPv4Option = (*IPv4OptionTimestamp)(nil) + +// IPv4OptionTimestamp is a Timestamp option from RFC 791. +type IPv4OptionTimestamp []byte + +// Type implements IPv4Option.Type(). +func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType } + +// Size implements IPv4Option. +func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) } + +// Contents implements IPv4Option. +func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) } + +// Pointer returns the pointer field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Pointer() uint8 { + return (*ts)[IPv4OptTSPointerOffset] +} + +// Flags returns the flags field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags { + return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask) +} + +// Overflow returns the Overflow field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Overflow() uint8 { + return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift +} + +// IncOverflow increments the Overflow field in the IP Timestamp option. It +// returns the incremented value. If the return value is 0 then the field +// overflowed. +func (ts *IPv4OptionTimestamp) IncOverflow() uint8 { + (*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift + return ts.Overflow() +} + +// UpdateTimestamp updates the fields of the next free timestamp slot. +func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) { + slot := (*ts)[ts.Pointer()-1:] + + switch ts.Flags() { + case IPv4OptionTimestampOnlyFlag: + binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize + case IPv4OptionTimestampWithIPFlag: + if n := copy(slot, addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + case IPv4OptionTimestampWithPredefinedIPFlag: + if tcpip.Address(slot[:IPv4AddressSize]) == addr { + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + } + } +} + +// RecordRoute option specific related constants. +// +// from RFC 791 page 20: +// Record Route +// +// +--------+--------+--------+---------//--------+ +// |00000111| length | pointer| route data | +// +--------+--------+--------+---------//--------+ +// Type=7 +// +// The record route option provides a means to record the route of +// an internet datagram. +// +// The option begins with the option type code. The second octet +// is the option length which includes the option type code and the +// length octet, the pointer octet, and length-3 octets of route +// data. The third octet is the pointer into the route data +// indicating the octet which begins the next area to store a route +// address. The pointer is relative to this option, and the +// smallest legal value for the pointer is 4. +const ( + // IPv4OptionRecordRouteHdrLength is the length of the Record Route option + // header. + IPv4OptionRecordRouteHdrLength = 3 + + // IPv4OptRRPointerOffset is the offset to the pointer field in an RR + // option, which points to the next free slot in the list of addresses. + IPv4OptRRPointerOffset = 2 +) + +var _ IPv4Option = (*IPv4OptionRecordRoute)(nil) + +// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791. +type IPv4OptionRecordRoute []byte + +// Pointer returns the pointer field in the IP RecordRoute option. +func (rr *IPv4OptionRecordRoute) Pointer() uint8 { + return (*rr)[IPv4OptRRPointerOffset] +} + +// StoreAddress stores the given IPv4 address into the next free slot. +func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) { + start := rr.Pointer() - 1 // A one based number. + // start and room checked by caller. + if n := copy((*rr)[start:], addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize +} + +// Type implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType } + +// Size implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } + +// Contents implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index c5d8a3456..09cb153b1 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -101,8 +101,10 @@ const ( // The address is ff02::2. IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, - // section 5. + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, + // section 5: + // IPv6 requires that every link in the Internet have an MTU of 1280 octets + // or greater. This is known as the IPv6 minimum link MTU. IPv6MinimumMTU = 1280 // IPv6Loopback is the IPv6 Loopback address. diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD new file mode 100644 index 000000000..ec92ed623 --- /dev/null +++ b/pkg/tcpip/link/ethernet/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "ethernet", + srcs = ["ethernet.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go new file mode 100644 index 000000000..3eef7cd56 --- /dev/null +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ethernet provides an implementation of an ethernet link endpoint that +// wraps an inner link endpoint. +package ethernet + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkDispatcher = (*Endpoint)(nil) +var _ stack.LinkEndpoint = (*Endpoint)(nil) + +// New returns an ethernet link endpoint that wraps an inner link endpoint. +func New(ep stack.LinkEndpoint) *Endpoint { + var e Endpoint + e.Endpoint.Init(ep, &e) + return &e +} + +// Endpoint is an ethernet endpoint. +// +// It adds an ethernet header to packets before sending them out through its +// inner link endpoint and consumes an ethernet header before sending the +// packet to the stack. +type Endpoint struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + return + } + + eth := header.Ethernet(hdr) + if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt) + } +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + return e.Endpoint.WritePacket(r, gso, proto, pkt) +} + +// WritePackets implements stack.LinkEndpoint. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + linkAddr := e.Endpoint.LinkAddress() + + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return header.EthernetMinimumSize + e.Endpoint.MaxHeaderLength() +} + +// ARPHardwareType implements stack.LinkEndpoint. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} + +// AddHeader implements stack.LinkEndpoint. +func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + fields := header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: proto, + } + eth.Encode(&fields) +} diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 76f563811..523b0d24b 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -26,27 +26,23 @@ import ( var _ stack.LinkEndpoint = (*Endpoint)(nil) // New returns both ends of a new pipe. -func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) { +func New(linkAddr1, linkAddr2 tcpip.LinkAddress) (*Endpoint, *Endpoint) { ep1 := &Endpoint{ - linkAddr: linkAddr1, - capabilities: capabilities, + linkAddr: linkAddr1, } ep2 := &Endpoint{ - linkAddr: linkAddr2, - linked: ep1, - capabilities: capabilities, + linkAddr: linkAddr2, } ep1.linked = ep2 + ep2.linked = ep1 return ep1, ep2 } // Endpoint is one end of a pipe. type Endpoint struct { - capabilities stack.LinkEndpointCapabilities - linkAddr tcpip.LinkAddress - dispatcher stack.NetworkDispatcher - linked *Endpoint - onWritePacket func(*stack.PacketBuffer) + dispatcher stack.NetworkDispatcher + linked *Endpoint + linkAddr tcpip.LinkAddress } // WritePacket implements stack.LinkEndpoint. @@ -55,16 +51,11 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network return nil } - // The pipe endpoint will accept all multicast/broadcast link traffic and only - // unicast traffic destined to itself. - if len(e.linked.linkAddr) != 0 && - r.RemoteLinkAddress != e.linked.linkAddr && - r.RemoteLinkAddress != header.EthernetBroadcastAddress && - !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) { - return nil - } - - e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // Note that the local address from the perspective of this endpoint is the + // remote address from the perspective of the other end of the pipe + // (e.linked). Similarly, the remote address from the perspective of this + // endpoint is the local address on the other end. + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), })) @@ -100,8 +91,8 @@ func (*Endpoint) MTU() uint32 { } // Capabilities implements stack.LinkEndpoint. -func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.capabilities +func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 } // MaxHeaderLength implements stack.LinkEndpoint. @@ -116,7 +107,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // ARPHardwareType implements stack.LinkEndpoint. func (*Endpoint) ARPHardwareType() header.ARPHardwareType { - return header.ARPHardwareEther + return header.ARPHardwareNone } // AddHeader implements stack.LinkEndpoint. diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index dc239a0d0..2777f1411 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) { const count = 1000000 var wg sync.WaitGroup + defer wg.Wait() wg.Add(1) go func() { defer wg.Done() @@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) { } }() - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } + for i := 0; i < count; i++ { + n := 1 + rr.Intn(80) + rb := rx.Pull() + for rb == nil { + rb = rx.Pull() + } - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } + if n != len(rb) { + t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) + } - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } + for j := range rb { + if v := byte(rr.Intn(256)); v != rb[j] { + t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) } - - rx.Flush() } - }() - wg.Wait() + rx.Flush() + } } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 0243424f6..86f14db76 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "tun_endpoint_refs.go", package = "tun", prefix = "tunEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tunEndpoint", }, @@ -28,6 +28,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index f94491026..cda6328a2 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -150,7 +150,6 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // 2. Creating a new NIC. id := tcpip.NICID(s.UniqueID()) - // TODO(gvisor.dev/1486): enable leak check for tunEndpoint. endpoint := &tunEndpoint{ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), stack: s, @@ -158,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE name: name, isTap: prefix == "tap", } + endpoint.EnableLeakCheck() endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { endpoint.name = fmt.Sprintf("%s%d", prefix, id) diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 59710352b..c118a2929 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -12,6 +12,7 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index b40dde96b..8a6bcfc2c 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -30,5 +30,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7df77c66e..a79379abb 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -18,6 +18,7 @@ package arp import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -150,28 +151,36 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - // As per RFC 826, under Packet Reception: - // Swap hardware and protocol fields, putting the local hardware and - // protocol addresses in the sender fields. - // - // Send the packet to the (new) target hardware address on the same - // hardware on which the request was received. - origSender := h.HardwareAddressSender() - r.RemoteLinkAddress = tcpip.LinkAddress(origSender) respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), origSender) - copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress()) + if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + origSender := h.HardwareAddressSender() + if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize)) + } + if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -199,6 +208,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { + stack *stack.Stack } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -227,26 +237,44 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - r := &stack.Route{ - NetProto: ProtocolNumber, - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + if len(remoteLinkAddr) == 0 { + remoteLinkAddr = header.EthernetBroadcastAddress } - if len(r.RemoteLinkAddress) == 0 { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + + nicID := nic.ID() + if len(localAddr) == 0 { + addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if err != nil { + return err + } + + if len(addr.Address) == 0 { + return tcpip.ErrNetworkUnreachable + } + + localAddr = addr.Address + } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return tcpip.ErrBadLocalAddress } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + pkt.NetworkProtocolNumber = ProtocolNumber h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), linkEP.LinkAddress()) - copy(h.ProtocolAddressSender(), localAddr) - copy(h.ProtocolAddressTarget(), addr) - - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -286,6 +314,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu // Note, to make sure that the ARP endpoint receives ARP packets, the "arp" // address must be added to every NIC that should respond to ARP requests. See // ProtocolAddress for more details. -func NewProtocol(*stack.Stack) stack.NetworkProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 626af975a..bf1292bb8 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,11 @@ func (t eventType) String() string { type eventInfo struct { eventType eventType nicID tcpip.NICID - addr tcpip.Address - linkAddr tcpip.LinkAddress - state stack.NeighborState + entry stack.NeighborEntry } func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) } // arpDispatcher implements NUDDispatcher to validate the dispatching of @@ -96,35 +95,29 @@ type arpDispatcher struct { var _ stack.NUDDispatcher = (*arpDispatcher)(nil) -func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryChanged, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryRemoved, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } @@ -132,7 +125,7 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { select { case got := <-d.C: - if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" { + if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { return fmt.Errorf("got invalid event (-got +want):\n%s", diff) } case <-ctx.Done(): @@ -373,9 +366,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { wantEvent := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: test.senderAddr, - linkAddr: tcpip.LinkAddress(test.senderLinkAddr), - state: stack.Stale, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), + State: stack.Stale, + }, } if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { t.Fatal(err) @@ -404,9 +399,6 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) } - if got, want := neigh.LocalAddr, stackAddr; got != want { - t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want) - } if got, want := neigh.State, stack.Stale; got != want { t.Errorf("got neighbor State = %s, want = %s", got, want) } @@ -423,43 +415,164 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.LinkEndpoint + + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + + testAddr := tcpip.Address([]byte{1, 2, 3, 4}) + tests := []struct { name string + nicAddr tcpip.Address + localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Multicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Unicast with no local address available", remoteLinkAddr: remoteLinkAddr, - expectLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, }, { - name: "Multicast", + name: "Multicast with no local address available", remoteLinkAddr: "", - expectLinkAddr: header.EthernetBroadcastAddress, + expectedErr: tcpip.ErrNetworkUnreachable, }, } for _, test := range tests { - p := arp.NewProtocol(nil) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + }) + p := s.NetworkProtocolInstance(arp.ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err) - } + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + } + } - if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) - } + // We pass a test network interface to LinkAddressRequest with the same + // NIC ID and link endpoint used by the NIC we created earlier so that we + // can mock a link address request and observe the packets sent to the + // link endpoint even though the stack uses the real NIC to validate the + // local address. + if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + } + + rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { + t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr) + } + }) } } diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index ed502a473..936601287 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -136,8 +136,16 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // proto is the protocol number marked in the fragment being processed. It has // to be given here outside of the FragmentID struct because IPv6 should not use // the protocol to identify a fragment. +// +// releaseCB is a callback that will run when the fragment reassembly of a +// packet is complete or cancelled. releaseCB take a a boolean argument which is +// true iff the reassembly is cancelled due to timeout. releaseCB should be +// passed only with the first fragment of a packet. If more than one releaseCB +// are passed for the same packet, only the first releaseCB will be saved for +// the packet and the succeeding ones will be dropped by running them +// immediately with a false argument. func (f *Fragmentation) Process( - id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) ( + id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) ( buffer.VectorisedView, uint8, bool, error) { if first > last { return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) @@ -171,6 +179,12 @@ func (f *Fragmentation) Process( f.releaseReassemblersLocked() } } + if releaseCB != nil { + if !r.setCallback(releaseCB) { + // We got a duplicate callback. Release it immediately. + releaseCB(false /* timedOut */) + } + } f.mu.Unlock() res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) @@ -178,14 +192,14 @@ func (f *Fragmentation) Process( // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() - f.release(r) + f.release(r, false /* timedOut */) f.mu.Unlock() return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() f.size += consumed if done { - f.release(r) + f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. @@ -195,14 +209,14 @@ func (f *Fragmentation) Process( if tail == nil { break } - f.release(tail) + f.release(tail, false /* timedOut */) } } f.mu.Unlock() return res, firstFragmentProto, done, nil } -func (f *Fragmentation) release(r *reassembler) { +func (f *Fragmentation) release(r *reassembler, timedOut bool) { // Before releasing a fragment we need to check if r is already marked as done. // Otherwise, we would delete it twice. if r.checkDoneOrMark() { @@ -216,6 +230,8 @@ func (f *Fragmentation) release(r *reassembler) { log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) f.size = 0 } + + r.release(timedOut) // releaseCB may run. } // releaseReassemblersLocked releases already-expired reassemblers, then @@ -238,31 +254,31 @@ func (f *Fragmentation) releaseReassemblersLocked() { break } // If the oldest reassembler has already expired, release it. - f.release(r) + f.release(r, true /* timedOut*/) } } // PacketFragmenter is the book-keeping struct for packet fragmentation. type PacketFragmenter struct { - transportHeader buffer.View - data buffer.VectorisedView - reserve int - innerMTU int - fragmentCount int - currentFragment int - fragmentOffset int + transportHeader buffer.View + data buffer.VectorisedView + reserve int + fragmentPayloadLen int + fragmentCount int + currentFragment int + fragmentOffset int } // MakePacketFragmenter prepares the struct needed for packet fragmentation. // // pkt is the packet to be fragmented. // -// innerMTU is the maximum number of bytes of fragmentable data a fragment can +// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can // have. // // reserve is the number of bytes that should be reserved for the headers in // each generated fragment. -func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter { +func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter { // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be // repeated in each fragment. However we do not currently support any header // of that kind yet, so the following computation is valid for both IPv4 and @@ -273,13 +289,13 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) Pa var fragmentableData buffer.VectorisedView fragmentableData.AppendView(pkt.TransportHeader().View()) fragmentableData.Append(pkt.Data) - fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU + fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen return PacketFragmenter{ - data: fragmentableData, - reserve: reserve, - innerMTU: innerMTU, - fragmentCount: fragmentCount, + data: fragmentableData, + reserve: reserve, + fragmentPayloadLen: int(fragmentPayloadLen), + fragmentCount: int(fragmentCount), } } @@ -302,7 +318,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, }) // Copy data for the fragment. - copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU) + copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) offset := pf.fragmentOffset pf.fragmentOffset += copied diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index d3c7d7f92..5dcd10730 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -105,7 +105,7 @@ func TestFragmentationProcess(t *testing.T) { f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil) if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) @@ -240,7 +240,7 @@ func TestReassemblingTimeout(t *testing.T) { for _, event := range test.events { clock.Advance(event.clockAdvance) if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data)) + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil) if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -259,15 +259,15 @@ func TestReassemblingTimeout(t *testing.T) { func TestMemoryLimits(t *testing.T) { f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -283,9 +283,9 @@ func TestMemoryLimits(t *testing.T) { func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) got := f.size want := 1 @@ -377,7 +377,7 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } @@ -403,14 +403,14 @@ func TestPacketFragmenter(t *testing.T) { tests := []struct { name string - innerMTU int + fragmentPayloadLen uint32 transportHeaderLen int payloadSize int wantFragments []fragmentInfo }{ { name: "Packet exactly fits in MTU", - innerMTU: 1280, + fragmentPayloadLen: 1280, transportHeaderLen: 0, payloadSize: 1280, wantFragments: []fragmentInfo{ @@ -419,7 +419,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet exactly does not fit in MTU", - innerMTU: 1000, + fragmentPayloadLen: 1000, transportHeaderLen: 0, payloadSize: 1001, wantFragments: []fragmentInfo{ @@ -429,7 +429,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a transport header", - innerMTU: 560, + fragmentPayloadLen: 560, transportHeaderLen: 40, payloadSize: 560, wantFragments: []fragmentInfo{ @@ -439,7 +439,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a huge transport header", - innerMTU: 500, + fragmentPayloadLen: 500, transportHeaderLen: 1300, payloadSize: 500, wantFragments: []fragmentInfo{ @@ -458,7 +458,7 @@ func TestPacketFragmenter(t *testing.T) { originalPayload.AppendView(pkt.TransportHeader().View()) originalPayload.Append(pkt.Data) var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.innerMTU, reserve) + pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) for i := 0; ; i++ { fragPkt, offset, copied, more := pf.BuildNextFragment() wantFragment := test.wantFragments[i] @@ -474,8 +474,8 @@ func TestPacketFragmenter(t *testing.T) { if more != wantFragment.more { t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) } - if got := fragPkt.Size(); got > test.innerMTU { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU) + if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) } if got := fragPkt.AvailableHeaderBytes(); got != reserve { t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) @@ -497,3 +497,89 @@ func TestPacketFragmenter(t *testing.T) { }) } } + +func TestReleaseCallback(t *testing.T) { + const ( + proto = 99 + ) + + var result int + var callbackReasonIsTimeout bool + cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut } + + tests := []struct { + name string + callbacks []func(bool) + timeout bool + wantResult int + wantCallbackReasonIsTimeout bool + }{ + { + name: "callback runs on release", + callbacks: []func(bool){cb1}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "first callback is nil", + callbacks: []func(bool){nil, cb2}, + timeout: false, + wantResult: 2, + wantCallbackReasonIsTimeout: false, + }, + { + name: "two callbacks - first one is set", + callbacks: []func(bool){cb1, cb2}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "callback runs on timeout", + callbacks: []func(bool){cb1}, + timeout: true, + wantResult: 1, + wantCallbackReasonIsTimeout: true, + }, + { + name: "no callbacks", + callbacks: []func(bool){nil}, + timeout: false, + wantResult: 0, + wantCallbackReasonIsTimeout: false, + }, + } + + id := FragmentID{ID: 0} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result = 0 + callbackReasonIsTimeout = false + + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + + for i, cb := range test.callbacks { + _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb) + if err != nil { + t.Errorf("f.Process error = %s", err) + } + } + + r, ok := f.reassemblers[id] + if !ok { + t.Fatalf("Reassemberr not found") + } + f.release(r, test.timeout) + + if result != test.wantResult { + t.Errorf("got result = %d, want = %d", result, test.wantResult) + } + if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout { + t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9bb051a30..c0cc0bde0 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -41,6 +41,7 @@ type reassembler struct { heap fragHeap done bool creationTime int64 + callback func(bool) } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { @@ -123,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool { r.mu.Unlock() return prev } + +func (r *reassembler) setCallback(c func(bool)) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.callback != nil { + return false + } + r.callback = c + return true +} + +func (r *reassembler) release(timedOut bool) { + r.mu.Lock() + callback := r.callback + r.callback = nil + r.mu.Unlock() + + if callback != nil { + callback(timedOut) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index a0a04a027..fa2a70dc8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -105,3 +105,26 @@ func TestUpdateHoles(t *testing.T) { } } } + +func TestSetCallback(t *testing.T) { + result := 0 + reasonTimeout := false + + cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut } + + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + if !r.setCallback(cb1) { + t.Errorf("setCallback failed") + } + if r.setCallback(cb2) { + t.Errorf("setCallback should fail if one is already set") + } + r.release(true) + if result != 1 { + t.Errorf("got result = %d, want = 1", result) + } + if !reasonTimeout { + t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index d436873b6..969579601 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,11 +15,13 @@ package ip_test import ( + "strings" "testing" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -302,6 +304,10 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -320,6 +326,7 @@ func TestSourceAddressValidation(t *testing.T) { SrcAddr: src, DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -342,7 +349,6 @@ func TestSourceAddressValidation(t *testing.T) { SrcAddr: src, DstAddr: localIPv6Addr, }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -579,6 +585,7 @@ func TestIPv4Receive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -660,6 +667,7 @@ func TestIPv4ReceiveControl(t *testing.T) { SrcAddr: "\x0a\x00\x00\xbb", DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Create the ICMP header. icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) @@ -679,12 +687,17 @@ func TestIPv4ReceiveControl(t *testing.T) { SrcAddr: localIPv4Addr, DstAddr: remoteIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Make payload be non-zero. for i := dataOffset; i < len(view); i++ { view[i] = uint8(i) } + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) + // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. nic.testObject.protocol = 10 @@ -732,6 +745,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip1.SetChecksum(^ip1.CalculateChecksum()) + // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { frag1[i] = uint8(i) @@ -748,6 +763,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip2.SetChecksum(^ip2.CalculateChecksum()) + // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { frag2[i] = uint8(i) @@ -1020,3 +1037,406 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer _, _ = pkt.NetworkHeader().Consume(netHdrLen) return pkt } + +func TestWriteHeaderIncludedPacket(t *testing.T) { + const ( + nicID = 1 + transportProto = 5 + + dataLen = 4 + optionsLen = 4 + ) + + dataBuf := [dataLen]byte{1, 2, 3, 4} + data := dataBuf[:] + + ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1} + ipv4Options := ipv4OptionsBuf[:] + + ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} + ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] + + var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte + ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] + if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + nicAddr tcpip.Address + remoteAddr tcpip.Address + pktGen func(*testing.T, tcpip.Address) buffer.View + checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) + expectedErr *tcpip.Error + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with IHL too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 1, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 minimum size", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(header.IPv4MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv4 with options", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) + totalLen := ipHdrLen + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(ipHdrLen)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHdrLen), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) + } + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6 with extension header", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), + checker.IPPayload(ipv6PayloadWithExtHdr), + ) + }, + }, + { + name: "IPv6 minimum size", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(header.IPv6MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv6 too small", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + srcAddr tcpip.Address + }{ + { + name: "unspecified source", + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + }, + { + name: "random source", + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + e := channel.New(1, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) + + r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + } + defer r.Release() + + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + })); err != test.expectedErr { + t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := e.Read() + if !ok { + t.Fatal("expected a packet to be written") + } + test.checker(t, pkt.Pkt, subTest.srcAddr) + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7fc12e229..6252614ec 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -29,6 +29,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3407755ed..cf287446e 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -23,10 +24,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// handleControl handles the case when an ICMP packet contains the headers of -// the original packet that caused the ICMP one to be sent. This information is -// used to find out which transport endpoint must be notified about the ICMP -// packet. +// handleControl handles the case when an ICMP error packet contains the headers +// of the original packet that caused the ICMP one to be sent. This information +// is used to find out which transport endpoint must be notified about the ICMP +// packet. We only expect the payload, not the enclosing ICMP packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { @@ -73,20 +74,65 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } h := header.ICMPv4(v) + // Only do in-stack processing if the checksum is correct. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { + received.Invalid.Increment() + // It's possible that a raw socket expects to receive this regardless + // of checksum errors. If it's an echo request we know it's safe because + // we are the only handler, however other types do not cope well with + // packets with checksum errors. + switch h.Type() { + case header.ICMPv4Echo: + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + } + return + } + + iph := header.IPv4(pkt.NetworkHeader().View()) + var newOptions header.IPv4Options + if len(iph) > header.IPv4MinimumSize { + // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip + // type ICMP packets): + // If a Record Route and/or Time Stamp option is received in an + // ICMP Echo Request, this option (these options) SHOULD be + // updated to include the current host and included in the IP + // header of the Echo Reply message, without "truncation". + // Thus, the recorded route will be for the entire round trip. + // + // So we need to let the option processor know how it should handle them. + var op optionsUsage + if h.Type() == header.ICMPv4Echo { + op = &optionUsageEcho{} + } else { + op = &optionUsageReceive{} + } + aux, tmp, err := processIPOptions(r, iph.Options(), op) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + } + return + } + newOptions = tmp + } + // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - // Only send a reply if the checksum is valid. - headerChecksum := h.Checksum() - h.SetChecksum(0) - calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - h.SetChecksum(headerChecksum) - if calculatedChecksum != headerChecksum { - // It's possible that a raw socket still expects to receive this. - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - received.Invalid.Increment() + sent := stats.ICMP.V4PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() return } @@ -98,9 +144,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. replyData := pkt.Data.ToOwnedView() - replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) + // It's possible that a raw socket expects to receive this. e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + pkt = nil + // Take the base of the incoming request IP header but replace the options. + replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions)) + replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...)) + replyIPHdr.SetHeaderLength(replyHeaderLength) // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast @@ -139,7 +190,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // The fields we need to alter. // // We need to produce the entire packet in the data segment in order to - // use WriteHeaderIncludedPacket(). + // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the + // total length and the header checksum so we don't need to set those here. replyIPHdr.SetSourceAddress(r.LocalAddress) replyIPHdr.SetDestinationAddress(r.RemoteAddress) replyIPHdr.SetTTL(r.DefaultTTL()) @@ -157,8 +209,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { }) replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - // The checksum will be calculated so we don't need to do it here. - sent := stats.ICMP.V4PacketsSent if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return @@ -182,8 +232,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: - mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) } case header.ICMPv4SrcQuench: @@ -234,6 +287,21 @@ type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + +// icmpReasonParamProblem is an error to use to request a Parameter Problem +// message to be sent. +type icmpReasonParamProblem struct { + pointer byte +} + +func (*icmpReasonParamProblem) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -374,17 +442,29 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - switch reason.(type) { + var counter *tcpip.StatCounter + switch reason := reason.(type) { case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) + counter = sent.DstUnreachable case *icmpReasonProtoUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) + counter = sent.TimeExceeded + case *icmpReasonParamProblem: + icmpHdr.SetType(header.ICMPv4ParamProblem) + icmpHdr.SetCode(header.ICMPv4UnusedCode) + icmpHdr.SetPointer(reason.pointer) + counter = sent.ParamProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) - counter := sent.DstUnreachable if err := route.WritePacket( nil, /* gso */ diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c5ac7b8b5..4592984a5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -16,7 +16,9 @@ package ipv4 import ( + "errors" "fmt" + "math" "sync/atomic" "time" @@ -31,6 +33,8 @@ import ( ) const ( + // ReassembleTimeout is the time a packet stays in the reassembly + // system before being evicted. // As per RFC 791 section 3.2: // The current recommendation for the initial timer setting is 15 seconds. // This may be changed as experience with this protocol accumulates. @@ -38,7 +42,7 @@ const ( // Considering that it is an old recommendation, we use the same reassembly // timeout that linux defines, which is 30 seconds: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138 - reassembleTimeout = 30 * time.Second + ReassembleTimeout = 30 * time.Second // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -176,7 +180,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and @@ -190,29 +198,6 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -// writePacketFragments fragments pkt and writes the results on the link -// endpoint. The IP header must already present in the original packet. The mtu -// is the maximum size of the packets. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error { - networkHeader := header.IPv4(pkt.NetworkHeader().View()) - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) - - for { - fragPkt, more := buildNextFragment(&pf, networkHeader) - if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) - return err - } - r.Stats().IP.PacketsSent.Increment() - if !more { - break - } - } - - return nil -} - func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) length := uint16(pkt.Size()) @@ -234,10 +219,36 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } +// handleFragments fragments pkt and calls the handler function on each +// fragment. It returns the number of fragments handled and the number of +// fragments left to be processed. The IP header must already be present in the +// original packet. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + // Round the MTU down to align to 8 bytes. + fragmentPayloadSize := networkMTU &^ 7 + networkHeader := header.IPv4(pkt.NetworkHeader().View()) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader)) + + var n int + for { + fragPkt, more := buildNextFragment(&pf, networkHeader) + if err := handler(fragPkt); err != nil { + return n, pf.RemainingFragmentCount() + 1, err + } + n++ + if !more { + return n, pf.RemainingFragmentCount(), nil + } + } +} + // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt) +} +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -273,9 +284,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, e.nic.MTU(), pkt) + + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each + // fragment one by one using WritePacket() (current strategy) or if we + // want to create a PacketBufferList from the fragments and feed it to + // WritePackets(). It'll be faster but cost more memory. + return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + }) + r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + return err + } + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err @@ -293,9 +321,29 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } - for pkt := pkts.Front(); pkt != nil; { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.addIPHeader(r, pkt, params) - pkt = pkt.Next() + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + // Keep track of the packet that is about to be fragmented so it can be + // removed once the fragmentation is done. + originalPkt := pkt + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Modify the packet list in place with the new fragments. + pkts.InsertAfter(pkt, fragPkt) + pkt = fragPkt + return nil + }); err != nil { + panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err)) + } + // Remove the packet that was just fragmented and process the rest. + pkts.Remove(originalPkt) + } } nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -347,30 +395,27 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacket writes a packet already containing a network -// header through the given route. +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { - return tcpip.ErrInvalidOptionValue + return tcpip.ErrMalformedHeader } ip := header.IPv4(h) - if !ip.IsValid(pkt.Data.Size()) { - return tcpip.ErrInvalidOptionValue - } // Always set the total length. - ip.SetTotalLength(uint16(pkt.Data.Size())) + pktSize := pkt.Data.Size() + ip.SetTotalLength(uint16(pktSize)) // Set the source address when zero. - if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + if ip.SourceAddress() == header.IPv4Any { ip.SetSourceAddress(r.LocalAddress) } - // Set the destination. If the packet already included a destination, - // it will be part of the route. + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. ip.SetDestinationAddress(r.RemoteAddress) // Set the packet ID when zero. @@ -387,19 +432,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if r.Loop&stack.PacketLoop != 0 { - e.HandlePacket(r, pkt.Clone()) - } - if r.Loop&stack.PacketOut == 0 { - return nil + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader } - if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return err - } - r.Stats().IP.PacketsSent.Increment() - return nil + return e.writePacket(r, nil /* gso */, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -415,6 +458,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if h.CalculateChecksum() != 0xffff { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + // As per RFC 1122 section 3.2.1.3: // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or @@ -455,6 +524,28 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { r.Stats().IP.MalformedFragmentsReceived.Increment() return } + + // Set up a callback in case we need to send a Time Exceeded Message, as per + // RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards + // the datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + r := r.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + } + r.Release() + } + } + var ready bool var err error proto := h.Protocol() @@ -472,6 +563,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { h.More(), proto, pkt.Data, + releaseCB, ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -481,9 +573,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if !ready { return } - } + // The reassembler doesn't take care of fixing up the header, so we need + // to do it here. + h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) + h.SetFlagsFragmentOffset(0, 0) + } r.Stats().IP.PacketsDelivered.Increment() + p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport @@ -493,6 +590,27 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { e.handleICMP(r, pkt) return } + if len(h.Options()) != 0 { + // TODO(gvisor.dev/issue/4586): + // When we add forwarding support we should use the verified options + // rather than just throwing them away. + aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{}) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + } + return + } + } switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { case stack.TransportPacketHandled: @@ -727,26 +845,32 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload mtu. +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv4MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState } - return mtu - header.IPv4MinimumSize -} -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize + // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in + // length: + // The maximal internet header is 60 octets, and a typical internet header + // is 20 octets, allowing a margin for headers of higher level protocols. + if networkHeaderSize > header.IPv4MaximumHeaderSize { + return 0, tcpip.ErrMalformedHeader } - mtu -= uint32(pkt.NetworkHeader().View().Size()) - // Round the MTU down to align to 8 bytes. - mtu &^= 7 - return mtu + + networkMTU := linkMTU + if networkMTU > MaxTotalSize { + networkMTU = MaxTotalSize + } + + return networkMTU - uint32(networkHeaderSize), nil +} + +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // addressToUint32 translates an IPv4 address into its little endian uint32 @@ -785,7 +909,7 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), } } @@ -811,3 +935,322 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head return fragPkt, more } + +// optionAction describes possible actions that may be taken on an option +// while processing it. +type optionAction uint8 + +const ( + // optionRemove says that the option should not be in the output option set. + optionRemove optionAction = iota + + // optionProcess says that the option should be fully processed. + optionProcess + + // optionVerify says the option should be checked and passed unchanged. + optionVerify + + // optionPass says to pass the output set without checking. + optionPass +) + +// optionActions list what to do for each option in a given scenario. +type optionActions struct { + // timestamp controls what to do with a Timestamp option. + timestamp optionAction + + // recordroute controls what to do with a Record Route option. + recordRoute optionAction + + // unknown controls what to do with an unknown option. + unknown optionAction +} + +// optionsUsage specifies the ways options may be operated upon for a given +// scenario during packet processing. +type optionsUsage interface { + actions() optionActions +} + +// optionUsageReceive implements optionsUsage for received packets. +type optionUsageReceive struct{} + +// actions implements optionsUsage. +func (*optionUsageReceive) actions() optionActions { + return optionActions{ + timestamp: optionVerify, + recordRoute: optionVerify, + unknown: optionPass, + } +} + +// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it +// is enabled (Process, Process, Pass) and for fragmenting (Process, Process, +// Pass for frag1, but Remove,Remove,Remove for all other frags). + +// optionUsageEcho implements optionsUsage for echo packet processing. +type optionUsageEcho struct{} + +// actions implements optionsUsage. +func (*optionUsageEcho) actions() optionActions { + return optionActions{ + timestamp: optionProcess, + recordRoute: optionProcess, + unknown: optionRemove, + } +} + +var ( + errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") + errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") + errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") + errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") +) + +// handleTimestamp does any required processing on a Timestamp option +// in place. +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { + flags := tsOpt.Flags() + var entrySize uint8 + switch flags { + case header.IPv4OptionTimestampOnlyFlag: + entrySize = header.IPv4OptionTimestampSize + case + header.IPv4OptionTimestampWithIPFlag, + header.IPv4OptionTimestampWithPredefinedIPFlag: + entrySize = header.IPv4OptionTimestampWithAddrSize + default: + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + } + + pointer := tsOpt.Pointer() + // To simplify processing below, base further work on the array of timestamps + // beyond the header, rather than on the whole option. Also to aid + // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. + nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1) + optLen := tsOpt.Size() + dataLength := optLen - header.IPv4OptionTimestampHdrLength + + // In the section below, we verify the pointer, length and overflow counter + // fields of the option. The distinction is in which byte you return as being + // in error in the ICMP packet. Offsets 1 (length), 2 pointer) + // or 3 (overflowed counter). + // + // The following RFC sections cover this section: + // + // RFC 791 (page 22): + // If there is some room but not enough room for a full timestamp + // to be inserted, or the overflow count itself overflows, the + // original datagram is considered to be in error and is discarded. + // In either case an ICMP parameter problem message may be sent to + // the source host [3]. + // + // You can get this situation in two ways. Firstly if the data area is not + // a multiple of the entry size or secondly, if the pointer is not at a + // multiple of the entry size. The wording of the RFC suggests that + // this is not an error until you actually run out of space. + if pointer > optLen { + // RFC 791 (page 22) says we should switch to using the overflow count. + // If the timestamp data area is already full (the pointer exceeds + // the length) the datagram is forwarded without inserting the + // timestamp, but the overflow count is incremented by one. + if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { + // By definition we have nothing to do. + return 0, nil + } + + if tsOpt.IncOverflow() != 0 { + return 0, nil + } + // The overflow count is also full. + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + } + if nextSlot+entrySize > dataLength { + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. + if false { + // We must select Pointer for Linux compatibility, even if + // only the length is bad. + // The Linux code is at (in October 2020) + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + // which doesn't distinguish between which of optptr[2] or optlen + // is wrong, but just arbitrarily decides on optptr+2. + if dataLength%entrySize != 0 { + // The Data section size should be a multiple of the expected + // timestamp entry size. + return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + } + // If the size is OK, the pointer must be corrupted. + } + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } + + if usage.actions().timestamp == optionProcess { + tsOpt.UpdateTimestamp(localAddress, clock) + } + return 0, nil +} + +var ( + errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") + errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") +) + +// handleRecordRoute checks and processes a Record route option. It is much +// like the timestamp type 1 option, but without timestamps. The passed in +// address is stored in the option in the correct spot if possible. +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { + optlen := rrOpt.Size() + + if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + + nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based. + + // RFC 791 page 21 says + // If the route data area is already full (the pointer exceeds the + // length) the datagram is forwarded without inserting the address + // into the recorded route. If there is some room but not enough + // room for a full address to be inserted, the original datagram is + // considered to be in error and is discarded. In either case an + // ICMP parameter problem message may be sent to the source + // host. + // The use of the words "In either case" suggests that a 'full' RR option + // could generate an ICMP at every hop after it fills up. We chose to not + // do this (as do most implementations). It is probable that the inclusion + // of these words is a copy/paste error from the timestamp option where + // there are two failure reasons given. + if nextSlot >= optlen { + return 0, nil + } + + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. We must select Pointer for Linux + // compatibility, even if only the length is bad. + if nextSlot+header.IPv4AddressSize > optlen { + if false { + // This is what we would do if we were not being Linux compatible. + // Check for bad pointer or length value. Must be a multiple of 4 after + // accounting for the 3 byte header and not within that header. + // RFC 791, page 20 says: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // + // A recorded route is composed of a series of internet addresses. + // Each internet address is 32 bits or 4 octets. + // Linux skips this test so we must too. See Linux code at: + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { + // Length is bad, not on integral number of slots. + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + // If not length, the fault must be with the pointer. + } + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } + if usage.actions().recordRoute == optionVerify { + return 0, nil + } + rrOpt.StoreAddress(localAddress) + return 0, nil +} + +// processIPOptions parses the IPv4 options and produces a new set of options +// suitable for use in the next step of packet processing as informed by usage. +// The original will not be touched. +// +// Returns +// - The location of an error if there was one (or 0 if no error) +// - If there is an error, information as to what it was was. +// - The replacement option set. +func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { + + opts := header.IPv4Options(orig) + optIter := opts.MakeIterator() + + // Each option other than NOP must only appear (RFC 791 section 3.1, at the + // definition of every type). Keep track of each of the possible types in + // the 8 bit 'type' field. + var seenOptions [math.MaxUint8 + 1]bool + + // TODO(gvisor.dev/issue/4586): + // This will need tweaking when we start really forwarding packets + // as we may need to get two addresses, for rx and tx interfaces. + // We will also have to take usage into account. + prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber) + localAddress := prefixedAddress.Address + if err != nil { + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + } + localAddress = r.LocalAddress + } + + for { + option, done, err := optIter.Next() + if done || err != nil { + return optIter.ErrCursor, optIter.Finalize(), err + } + optType := option.Type() + if optType == header.IPv4OptionNOPType { + optIter.PushNOPOrEnd(optType) + continue + } + if optType == header.IPv4OptionListEndType { + optIter.PushNOPOrEnd(optType) + return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + } + + // check for repeating options (multiple NOPs are OK) + if seenOptions[optType] { + return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + } + seenOptions[optType] = true + + optLen := int(option.Size()) + switch option := option.(type) { + case *header.IPv4OptionTimestamp: + r.Stats().IP.OptionTSReceived.Increment() + if usage.actions().timestamp != optionRemove { + clock := r.Stack().Clock() + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + case *header.IPv4OptionRecordRoute: + r.Stats().IP.OptionRRReceived.Increment() + if usage.actions().recordRoute != optionRemove { + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + default: + r.Stats().IP.OptionUnknownReceived.Increment() + if usage.actions().unknown == optionPass { + newBuffer := optIter.RemainingBuffer()[:optLen] + // Arguments already heavily checked.. ignore result. + _ = copy(newBuffer, option.Contents()) + optIter.ConsumeBuffer(optLen) + } + } + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 819b5d71f..61672a5ff 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -21,11 +21,13 @@ import ( "math" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -39,13 +41,17 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +const ( + extraHeaderReserve = 50 + defaultMTU = 65536 +) + func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - const defaultMTU = 65536 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { ep = sniffer.New(ep) @@ -101,7 +107,6 @@ func TestExcludeBroadcast(t *testing.T) { // checks the response. func TestIPv4Sanity(t *testing.T) { const ( - defaultMTU = header.IPv6MinimumMTU ttl = 255 nicID = 1 randomSequence = 123 @@ -116,23 +121,34 @@ func TestIPv4Sanity(t *testing.T) { ) tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - shouldFail bool - expectICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - options []byte + name string + headerLength uint8 // value of 0 means "use correct size" + badHeaderChecksum bool + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + options []byte + replyOptions []byte // if succeeds, reply should look like this + shouldFail bool + expectErrorICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + paramProblemPointer uint8 }{ { - name: "valid", - maxTotalLength: defaultMTU, + name: "valid no options", + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, }, + { + name: "bad header checksum", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + badHeaderChecksum: true, + shouldFail: true, + }, // The TTL tests check that we are not rejecting an incoming packet // with a zero or one TTL, which has been a point of confusion in the // past as RFC 791 says: "If this field contains the value zero, then the @@ -146,47 +162,47 @@ func TestIPv4Sanity(t *testing.T) { // received with TTL less than 2. { name: "zero TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 0, - shouldFail: false, }, { name: "one TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 1, - shouldFail: false, }, { name: "End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{0, 0, 0, 0}, + replyOptions: []byte{0, 0, 0, 0}, }, { name: "NOP options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{1, 1, 1, 1}, + replyOptions: []byte{1, 1, 1, 1}, }, { name: "NOP and End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{1, 1, 0, 0}, + replyOptions: []byte{1, 1, 0, 0}, }, { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (0)", @@ -194,7 +210,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip - 1)", @@ -202,7 +217,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip + icmp - 1)", @@ -210,28 +224,361 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad protocol", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: 99, TTL: ttl, shouldFail: true, - expectICMP: true, + expectErrorICMP: true, ICMPType: header.ICMPv4DstUnreachable, ICMPCode: header.ICMPv4ProtoUnreachable, }, + { + name: "timestamp option overflow", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + }, + { + name: "timestamp option overflow full", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0xF1, + // ^ Counter full (15/0xF) + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + replyOptions: []byte{}, + }, + { + name: "unknown option", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{10, 4, 9, 0}, + // ^^ + // The unknown option should be stripped out of the reply. + replyOptions: []byte{}, + }, + { + name: "bad option - length 0", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 0, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + name: "bad option - length big", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 9, 9, 0, + // ^ + // There are only 8 bytes allocated to options so 9 bytes of timestamp + // space is not possible. (Second byte) + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + // This tests for some linux compatible behaviour. + // The ICMP pointer returned is 22 for Linux but the + // error is actually in spot 21. + name: "bad option - length bad", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + // Timestamps are in multiples of 4 or 8 but never 7. + // The option space should be padded out. + options: []byte{ + 68, 7, 5, 0, + // ^ ^ Linux points here which is wrong. + // | Not a multiple of 4 + 1, 2, 3, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "multiple type 0 with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 24, 21, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 24, 25, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // The timestamp area is full so add to the overflow count. + name: "multiple type 1 timestamps", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 21, 0x11, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + // Overflow count is the top nibble of the 4th byte. + replyOptions: []byte{ + 68, 20, 21, 0x21, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + }, + { + name: "multiple type 1 timestamps with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 28, 21, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 28, 29, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 192, 168, 1, 58, // New IP Address. + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // Needs 8 bytes for a type 1 timestamp but there are only 4 free. + name: "bad timer element alignment", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 17, 0x01, + // ^^ ^^ 20 byte area, next free spot at 17. + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + // End of option list with illegal option after it, which should be ignored. + { + name: "end of options list", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 10, 3, 99, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, // 3 bytes unknown option + }, // ^ End of options hides following bytes. + }, + { + // Timestamp with a size too small. + name: "timestamp truncated", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{68, 1, 0, 0}, + // ^ Smallest possible is 8. + shouldFail: true, + }, + { + name: "single record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 4, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "multiple record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 20, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "single record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Unlike timestamp, this should just succeed. + name: "multiple record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 24, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm linux bug for bug compatibility. + // Linux returns slot 22 but the error is in slot 21. + name: "multiple record route with not enough room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 8, 8, // 3 byte header + // ^ ^ Linux points here. We must too. + // | Not enough room. 1 byte free, need 4. + 1, 2, 3, 4, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + replyOptions: []byte{}, + }, + { + name: "duplicate record route", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, 0, // pad + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 7, + replyOptions: []byte{}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, }) // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, defaultMTU, "") + e := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -239,6 +586,9 @@ func TestIPv4Sanity(t *testing.T) { if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } + // Advance the clock by some unimportant amount to make + // sure it's all set up. + clock.Advance(time.Millisecond * 0x10203040) // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. @@ -288,6 +638,12 @@ func TestIPv4Sanity(t *testing.T) { if test.headerLength != 0 { ip.SetHeaderLength(test.headerLength) } + ip.SetChecksum(0) + ipHeaderChecksum := ip.CalculateChecksum() + if test.badHeaderChecksum { + ipHeaderChecksum += 42 + } + ip.SetChecksum(^ipHeaderChecksum) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) @@ -295,14 +651,20 @@ func TestIPv4Sanity(t *testing.T) { reply, ok := e.Read() if !ok { if test.shouldFail { - if test.expectICMP { - t.Fatal("expected ICMP error response missing") + if test.expectErrorICMP { + t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) } return // Expected silent failure. } t.Fatal("expected ICMP echo reply missing") } + // We didn't expect a packet. Register our surprise but carry on to + // provide more information about what we got. + if test.shouldFail && !test.expectErrorICMP { + t.Error("unexpected packet response") + } + // Check the route that brought the packet to us. if reply.Route.LocalAddress != ipv4Addr.Address { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) @@ -311,57 +673,90 @@ func TestIPv4Sanity(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) } - // Make sure it's all in one buffer. - vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) - replyIPHeader := header.IPv4(vv.ToView()) + // Make sure it's all in one buffer for checker. + replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - // At this stage we only know it's an IP header so verify that much. + // At this stage we only know it's probably an IP+ICMP header so verify + // that much. checker.IPv4(t, replyIPHeader, checker.SrcAddr(ipv4Addr.Address), checker.DstAddr(remoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Checksum(), + ), ) - // All expected responses are ICMP packets. - if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { - t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + // Don't proceed any further if the checker found problems. + if t.Failed() { + t.FailNow() } - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - // Sanity check the response. + // OK it's ICMP. We can safely look at the type now. + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) switch replyICMPHeader.Type() { - case header.ICMPv4DstUnreachable: + case header.ICMPv4ParamProblem: + if !test.shouldFail { + t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) + } + if !test.expectErrorICMP { + t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) + } checker.IPv4(t, replyIPHeader, checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), checker.IPv4HeaderLength(header.IPv4MinimumSize), checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Checksum(), + checker.ICMPv4Pointer(test.paramProblemPointer), checker.ICMPv4Payload([]byte(hdr.View())), ), ) - if !test.shouldFail || !test.expectICMP { - t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + return + case header.ICMPv4DstUnreachable: + if !test.shouldFail { + t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + if !test.expectErrorICMP { + t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", header.ICMPv4DstUnreachable, replyICMPHeader.Code()) } + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) return case header.ICMPv4EchoReply: + if test.shouldFail { + if !test.expectErrorICMP { + t.Error("got Echo Reply packet, want no response") + } else { + t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) + } + } + // If the IP options change size then the packet will change size, so + // some IP header fields will need to be adjusted for the checks. + sizeChange := len(test.replyOptions) - len(test.options) + checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength), - checker.IPv4Options(test.options), - checker.IPFullLength(uint16(requestPkt.Size())), + checker.IPv4HeaderLength(ipHeaderLength+sizeChange), + checker.IPv4Options(test.replyOptions), + checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), checker.ICMPv4( + checker.ICMPv4Checksum(), checker.ICMPv4Code(header.ICMPv4UnusedCode), checker.ICMPv4Seq(randomSequence), checker.ICMPv4Ident(randomIdent), - checker.ICMPv4Checksum(), ), ) - if test.shouldFail { - t.Fatalf("unexpected Echo Reply packet\n") - } default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) } }) } @@ -369,7 +764,7 @@ func TestIPv4Sanity(t *testing.T) { // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32) error { +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { // Make a complete array of the sourcePacket packet. source := header.IPv4(packets[0].NetworkHeader().View()) vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) @@ -381,7 +776,6 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB sourceCopy.SetChecksum(0) sourceCopy.SetFlagsFragmentOffset(0, 0) sourceCopy.SetTotalLength(0) - var offset uint16 // Build up an array of the bytes sent. var reassembledPayload buffer.VectorisedView for i, packet := range packets { @@ -391,35 +785,38 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) } - if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { - return fmt.Errorf("fragment #%d: fragmentIPHeader.CalculateChecksum() got %#x, want %#x", i, got, want) - } if got := len(fragmentIPHeader); got > int(mtu) { return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) } - if got, want := packet.AvailableHeaderBytes(), sourcePacket.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { - return fmt.Errorf("fragment #%d: should have the same available space for prepending as source: got %d, want %d", i, got, want) + if got := fragmentIPHeader.TransportProtocol(); got != proto { + return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) + } + if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) } if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { - return fmt.Errorf("fragment #%d: has wrong network protocol number: got %d, want %d", i, got, want) + return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) + } + if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { + return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) } - if i < len(packets)-1 { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) + if wantFragments[i].more { + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) } reassembledPayload.AppendView(packet.TransportHeader().View()) reassembledPayload.Append(packet.Data) - offset += fragmentIPHeader.TotalLength() - uint16(fragmentIPHeader.HeaderLength()) // Clear out the checksum and length from the ip because we can't compare // it. - sourceCopy.SetTotalLength(uint16(len(fragmentIPHeader))) + sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { - return fmt.Errorf("fragment #%d: fragmentIPHeader[:fragmentIPHeader.HeaderLength()] mismatch (-want +got):\n%s", i, diff) + return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } } + expected := buffer.View(source[source.HeaderLength():]) if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) @@ -428,38 +825,122 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB return nil } -func TestFragmentation(t *testing.T) { - const ttl = 42 +type fragmentInfo struct { + offset uint16 + more bool + payloadSize uint16 +} - var manyPayloadViewsSizes [1000]int - for i := range manyPayloadViewsSizes { - manyPayloadViewsSizes[i] = 7 - } - fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - transportHeaderLength int - extraHeaderReserveLength int - payloadViewsSizes []int - expectedFrags int - }{ - {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, - {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, - } +var fragmentationTests = []struct { + description string + mtu uint32 + gso *stack.GSO + transportHeaderLength int + payloadSize int + wantFragments []fragmentInfo +}{ + { + description: "No fragmentation", + mtu: 1280, + gso: nil, + transportHeaderLength: 0, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1000, more: false}, + }, + }, + { + description: "Fragmented", + mtu: 1280, + gso: nil, + transportHeaderLength: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 744, more: false}, + }, + }, + { + description: "Fragmented with the minimum mtu", + mtu: header.IPv4MinimumMTU, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv4MinimumMTU + 1, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { + description: "No fragmentation with big header", + mtu: 2000, + gso: nil, + transportHeaderLength: 100, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1100, more: false}, + }, + }, + { + description: "Fragmented with gso none", + mtu: 1280, + gso: &stack.GSO{Type: stack.GSONone}, + transportHeaderLength: 0, + payloadSize: 1400, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 144, more: false}, + }, + }, + { + description: "Fragmented with big header", + mtu: 1280, + gso: nil, + transportHeaderLength: 100, + payloadSize: 1200, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 44, more: false}, + }, + }, + { + description: "Fragmented with MTU smaller than header", + mtu: 300, + gso: nil, + transportHeaderLength: 1000, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 280, more: true}, + {offset: 280, payloadSize: 280, more: true}, + {offset: 560, payloadSize: 280, more: true}, + {offset: 840, payloadSize: 280, more: true}, + {offset: 1120, payloadSize: 280, more: true}, + {offset: 1400, payloadSize: 100, more: false}, + }, + }, +} + +func TestFragmentationWritePacket(t *testing.T) { + const ttl = 42 - for _, ft := range fragTests { + for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -469,60 +950,149 @@ func TestFragmentation(t *testing.T) { if err != nil { t.Fatalf("r.WritePacket(_, _, _) = %s", err) } - - if got := len(ep.WrittenPackets); got != ft.expectedFrags { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags) + if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) } - if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) } if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - if err := compareFragments(ep.WrittenPackets, source, ft.mtu); err != nil { + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { t.Error(err) } }) } } -// TestFragmentationErrors checks that errors are returned from write packet +func TestFragmentationWritePackets(t *testing.T) { + const ttl = 42 + writePacketsTests := []struct { + description string + insertBefore int + insertAfter int + }{ + { + description: "Single packet", + insertBefore: 0, + insertAfter: 0, + }, + { + description: "With packet before", + insertBefore: 1, + insertAfter: 0, + }, + { + description: "With packet after", + insertBefore: 0, + insertAfter: 1, + }, + { + description: "With packet before and after", + insertBefore: 1, + insertAfter: 1, + }, + } + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + + for _, test := range writePacketsTests { + t.Run(test.description, func(t *testing.T) { + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + var pkts stack.PacketBufferList + for i := 0; i < test.insertBefore; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkts.PushBack(pkt.Clone()) + for i := 0; i < test.insertAfter; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + + wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter + n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }) + if err != nil { + t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) + } + if n != wantTotalPackets { + t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) + } + if got := len(ep.WrittenPackets); got != wantTotalPackets { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + + if wantTotalPackets == 0 { + return + } + + fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] + if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } + }) + } +} + +// TestFragmentationErrors checks that errors are returned from WritePacket // correctly. func TestFragmentationErrors(t *testing.T) { const ttl = 42 - expectedError := tcpip.ErrAborted - fragTests := []struct { + tests := []struct { description string mtu uint32 transportHeaderLength int payloadSize int allowPackets int - fragmentCount int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error }{ { description: "No frag", mtu: 2000, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 1, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 3, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on second frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 1, - fragmentCount: 3, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag MTU smaller than header", @@ -530,28 +1100,40 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 1000, payloadSize: 500, allowPackets: 0, - fragmentCount: 4, + outgoingErrors: 4, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error when MTU is smaller than IPv4 minimum MTU", + mtu: header.IPv4MinimumMTU - 1, + transportHeaderLength: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, }, } - for _, ft := range fragTests { + for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != expectedError { - t.Errorf("got WritePacket() = %s, want = %s", err, expectedError) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) } - if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) } - if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) } }) } @@ -583,7 +1165,6 @@ func TestInvalidFragments(t *testing.T) { autoChecksum bool // if true, the Checksum field will be overwritten. } - // These packets have both IHL and TotalLength set to 0. tests := []struct { name string fragments []fragmentData @@ -823,7 +1404,6 @@ func TestInvalidFragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -866,6 +1446,259 @@ func TestInvalidFragments(t *testing.T) { } } +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 99 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + }, + Clock: clock, + }) + e := channel.New(1, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + hdr := buffer.NewPrependable(pktSize) + + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && ip.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(header.IPv4ProtocolNumber, pkt) + } + + clock.Advance(ipv4.ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(firstFragmentSent)), + ), + ) + }) + } +} + // TestReceiveFragments feeds fragments in through the incoming packet path to // test reassembly func TestReceiveFragments(t *testing.T) { @@ -1281,6 +2114,7 @@ func TestReceiveFragments(t *testing.T) { SrcAddr: frag.srcAddr, DstAddr: frag.dstAddr, }) + ip.SetChecksum(^ip.CalculateChecksum()) vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) @@ -1415,7 +2249,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -1549,6 +2383,7 @@ func TestPacketQueing(t *testing.T) { SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, DstAddr: host1IPv4Addr.AddressWithPrefix.Address, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -1592,6 +2427,7 @@ func TestPacketQueing(t *testing.T) { SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, DstAddr: host1IPv4Addr.AddressWithPrefix.Address, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -1619,7 +2455,7 @@ func TestPacketQueing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index a30437f02..0ac24a6fb 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -36,6 +36,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ead6bedcb..3c15e41a7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -170,8 +170,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -284,7 +287,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } else if e.nud != nil { - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) } @@ -555,7 +558,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme if e.nud != nil { // A RS with a specified source IP address modifies the NUD state // machine in the same way a reachability probe would. - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } } @@ -605,7 +608,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(routerAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } e.mu.Lock() @@ -648,52 +651,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - // TODO(b/148672031): Use stack.FindRoute instead of manually creating the - // route here. Note, we would need the nicID to do this properly so the right - // NIC (associated to linkEP) is used to send the NDP NS message. - r := stack.Route{ - LocalAddress: localAddr, - RemoteAddress: addr, - LocalLinkAddress: linkEP.LinkAddress(), - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + remoteAddr := targetAddr + if len(remoteLinkAddr) == 0 { + remoteAddr = header.SolicitedNodeAddr(targetAddr) + remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - // If a remote address is not already known, then send a multicast - // solicitation since multicast addresses have a static mapping to link - // addresses. - if len(r.RemoteLinkAddress) == 0 { - r.RemoteAddress = header.SolicitedNodeAddr(addr) - r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress) + r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err } + defer r.Release() + r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()), + header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize, + ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) ns := header.NDPNeighborSolicit(packet.NDPPayload()) - ns.SetTargetAddress(addr) + ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) + stat := p.stack.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, pkt); err != nil { + stat.Dropped.Increment() + return err + } - // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt) + stat.NeighborSolicit.Increment() + return nil } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -747,6 +744,13 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { @@ -839,7 +843,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + payload := network.ToVectorisedView() + payload.AppendView(transport) + payload.Append(pkt.Data) payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -860,6 +866,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) + counter = sent.TimeExceeded default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8dc33c560..aa8b5f2e5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -51,6 +51,7 @@ const ( var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) + lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -108,31 +109,27 @@ type stubNUDHandler struct { var _ stack.NUDHandler = (*stubNUDHandler)(nil) -func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { s.probeCount++ } -func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { s.confirmationCount++ } -func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { +func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { } var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - stack.NetworkLinkEndpoint - - linkAddr tcpip.LinkAddress -} + stack.LinkEndpoint -func (i *testInterface) LinkAddress() tcpip.LinkAddress { - return i.linkAddr + nicID tcpip.NICID } func (*testInterface) ID() tcpip.NICID { - return 0 + return nicID } func (*testInterface) IsLoopback() bool { @@ -147,6 +144,14 @@ func (*testInterface) Enabled() bool { return true } +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -1235,26 +1240,72 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { } func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + snaddr := header.SolicitedNodeAddr(lladdr0) mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) tests := []struct { - name string - remoteLinkAddr tcpip.LinkAddress - expectedLinkAddr tcpip.LinkAddress - expectedAddr tcpip.Address + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedRemoteAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", - remoteLinkAddr: linkAddr1, - expectedLinkAddr: linkAddr1, - expectedAddr: lladdr0, + name: "Unicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, + }, + { + name: "Multicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, }, { - name: "Multicast", - remoteLinkAddr: "", - expectedLinkAddr: mcaddr, - expectedAddr: snaddr, + name: "Multicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Unicast with no local address available", + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, }, } @@ -1269,26 +1320,43 @@ func TestLinkAddressRequest(t *testing.T) { } linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } + } + + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return } pkt, ok := linkEP.Read() if !ok { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } - if pkt.Route.RemoteAddress != test.expectedAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr) + if pkt.Route.RemoteAddress != test.expectedRemoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) } if pkt.Route.LocalAddress != lladdr1 { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) } checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedAddr), + checker.DstAddr(test.expectedRemoteAddr), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(lladdr0), @@ -1698,7 +1766,7 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } nudHandler := &stubNUDHandler{} - ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 2bd8f4ece..1e38f3a9d 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -41,12 +41,12 @@ const ( // // Linux also uses 60 seconds for reassembly timeout: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456 - reassembleTimeout = 60 * time.Second + ReassembleTimeout = 60 * time.Second // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber - // maxTotalSize is maximum size that can be encoded in the 16-bit + // maxPayloadSize is the maximum size that can be encoded in the 16-bit // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff @@ -363,7 +363,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and @@ -386,27 +390,40 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. The transport -// header protocol number is required to avoid parsing the IPv6 extension -// headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) - if fragMTU < pkt.TransportHeader().View().Size() { +// original packet. The transport header protocol number is required to avoid +// parsing the IPv6 extension headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // maximum payload length because they should only be transmitted once. + fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7 + if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit { + // We need at least 8 bytes of space left for the fragmentable part because + // the fragment payload must obviously be non-zero and must be a multiple + // of 8 as per RFC 8200 section 4.5: + // Each complete fragment, except possibly the last ("rightmost") one, is + // an integer multiple of 8 octets long. + return 0, 1, tcpip.ErrMessageTooLong + } + + if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { // As per RFC 8200 Section 4.5, the Transport Header is expected to be small // enough to fit in the first fragment. return 0, 1, tcpip.ErrMessageTooLong } - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) - networkHeader := header.IPv6(pkt.NetworkHeader().View()) var n int for { @@ -416,17 +433,18 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p } n++ if !more { - break + return n, pf.RemainingFragmentCount(), nil } } - - return n, 0, nil } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt, params.Protocol) +} +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -467,8 +485,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -498,24 +522,30 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) - if e.packetMustBeFragmented(pb, gso) { - current := pb - _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + + networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + if packetMustBeFragmented(pb, networkMTU, gso) { + // Keep track of the packet that is about to be fragmented so it can be + // removed once the fragmentation is done. + originalPkt := pb + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. - pkts.InsertAfter(current, fragPkt) - current = current.Next() + pkts.InsertAfter(pb, fragPkt) + pb = fragPkt return nil - }) - if err != nil { + }); err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } - // The fragmented packet can be released. The rest of the packets can be - // processed. - pkts.Remove(pb) - pb = current + // Remove the packet that was just fragmented and process the rest. + pkts.Remove(originalPkt) } } @@ -569,11 +599,40 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet -// supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { - // TODO(b/146666412): Support IPv6 header-included packets. - return tcpip.ErrNotSupported +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { + // The packet already has an IP header, but there are a few required checks. + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return tcpip.ErrMalformedHeader + } + ip := header.IPv6(h) + + // Always set the payload length. + pktSize := pkt.Data.Size() + ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + + // Set the source address when zero. + if ip.SourceAddress() == header.IPv6Any { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. + ip.SetDestinationAddress(r.RemoteAddress) + + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + proto, _, _, _, ok := parse.IPv6(pkt) + if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader + } + + return e.writePacket(r, nil /* gso */, pkt, proto) } // HandlePacket is called by the link layer when new ipv6 packets arrive for @@ -718,6 +777,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { continue } + fragmentFieldOffset := it.ParseOffset() + // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. @@ -775,17 +836,59 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } + // As per RFC 2460 Section 4.5: + // + // If the length of a fragment, as derived from the fragment packet's + // Payload Length field, is not a multiple of 8 octets and the M flag + // of that fragment is 1, then that fragment must be discarded and an + // ICMP Parameter Problem, Code 0, message should be sent to the source + // of the fragment, pointing to the Payload Length field of the + // fragment packet. + if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: header.IPv6PayloadLenOffset, + }, pkt) + return + } + // The packet is a fragment, let's try to reassemble it. start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - // Drop the fragment if the size of the reassembled payload would exceed - // the maximum payload size. + // As per RFC 2460 Section 4.5: + // + // If the length and offset of a fragment are such that the Payload + // Length of the packet reassembled from that fragment would exceed + // 65,535 octets, then that fragment must be discarded and an ICMP + // Parameter Problem, Code 0, message should be sent to the source of + // the fragment, pointing to the Fragment Offset field of the fragment + // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: fragmentFieldOffset, + }, pkt) return } + // Set up a callback in case we need to send a Time Exceeded Message as + // per RFC 2460 Section 4.5. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + r := r.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + } + r.Release() + } + } + // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. data, proto, ready, err := e.protocol.fragmentation.Process( @@ -801,6 +904,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.More(), uint8(rawPayload.Identifier), rawPayload.Buf, + releaseCB, ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -1398,14 +1502,31 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - mtu -= header.IPv6MinimumSize - if mtu <= maxPayloadSize { - return mtu +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload MTU and the length of every IPv6 header. +// Note that this is different than the Payload Length field of the IPv6 header, +// which includes the length of the extension headers. +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv6MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState + } + + // As per RFC 7112 section 5, we should discard packets if their IPv6 header + // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not + // support PMTU discovery: + // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain + // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280 + // bytes ensures that the header chain length does not exceed the IPv6 + // minimum MTU. + if networkHeadersLen > header.IPv6MinimumMTU { + return 0, tcpip.ErrMalformedHeader + } + + networkMTU := linkMTU - uint32(networkHeadersLen) + if networkMTU > maxPayloadSize { + networkMTU = maxPayloadSize } - return maxPayloadSize + return networkMTU, nil } // Options holds options to configure a new protocol. @@ -1459,7 +1580,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), ids: ids, hashIV: hashIV, @@ -1480,23 +1601,6 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return NewProtocolWithOptions(Options{})(s) } -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are - // supported for outbound packets, their length should not affect the fragment - // MTU because they should only be transmitted once. - mtu -= uint32(pkt.NetworkHeader().View().Size()) - mtu -= header.IPv6FragmentHeaderSize - // Round the MTU down to align to 8 bytes. - mtu &^= 7 - if mtu <= maxPayloadSize { - return mtu - } - return maxPayloadSize -} - func calculateFragmentReserve(pkt *stack.PacketBuffer) int { return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index bee18d1a8..c593c0004 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" @@ -49,6 +50,8 @@ const ( fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) + + extraHeaderReserve = 50 ) // testReceiveICMP tests receiving an ICMP packet from src to dst. want is the @@ -181,6 +184,9 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } + if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) + } if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) } @@ -208,8 +214,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB reassembledPayload.Append(fragment.Data) } - result := reassembledPayload.ToView() - if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" { + if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } @@ -234,7 +239,7 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) + e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } @@ -267,7 +272,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -821,7 +826,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1840,7 +1845,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1908,16 +1913,19 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - nicID = 1 - fragmentExtHdrLen = 8 + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_INVALID_IPV6_FRAGMENTS" ) - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte } tests := []struct { @@ -1925,31 +1933,64 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 + expectICMP bool + expectICMPType header.ICMPv6Type + expectICMPCode header.ICMPv6Code + expectICMPTypeSpecific uint32 }{ { + name: "fragment size is not a multiple of 8 and the M flag is true", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 9, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0 >> 3, + M: true, + Identification: ident, + }, + payload: []byte(data)[:9], + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6PayloadLenOffset, + }, + { name: "fragments reassembled into a payload exceeding the max IPv6 payload size", fragments: []fragmentData{ { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16, - []buffer.View{ - // Fragment extension header. - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8, - ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8, - 0, 0, 0, 1}), - // Payload length = 16 - payloadGen(16), - }, - ), + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, + M: false, + Identification: ident, + }, + payload: []byte(data)[:16], }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ }, } @@ -1960,33 +2001,40 @@ func TestInvalidIPv6Fragments(t *testing.T) { NewProtocol, }, }) - e := channel.New(0, 1500, linkAddr1) + e := channel.New(1, 1500, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + var expectICMPPayload buffer.View for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - // Serialize IPv6 fixed header. - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, - }) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) vv := hdr.View().ToVectorisedView() - vv.Append(f.data) + vv.AppendView(f.payload) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - })) + }) + + if test.expectICMP { + expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { @@ -1995,6 +2043,287 @@ func TestInvalidIPv6Fragments(t *testing.T) { if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) } + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), + checker.ICMPv6( + checker.ICMPv6Type(test.expectICMPType), + checker.ICMPv6Code(test.expectICMPCode), + checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), + checker.ICMPv6Payload([]byte(expectICMPPayload)), + ), + ) + }) + } +} + +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + Clock: clock, + }) + + e := channel.New(1, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) + } + + clock.Advance(ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), + checker.ICMPv6Payload([]byte(firstFragmentSent)), + ), + ) }) } } @@ -2217,24 +2546,19 @@ type fragmentInfo struct { payloadSize uint16 } -type fragmentationTestCase struct { +var fragmentationTests = []struct { description string mtu uint32 gso *stack.GSO transHdrLen int - extraHdrLen int payloadSize int wantFragments []fragmentInfo - expectedFrags int -} - -var fragmentationTests = []fragmentationTestCase{ +}{ { - description: "No Fragmentation", - mtu: 1280, - gso: &stack.GSO{}, + description: "No fragmentation", + mtu: header.IPv6MinimumMTU, + gso: nil, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1000, more: false}, @@ -2242,10 +2566,20 @@ var fragmentationTests = []fragmentationTestCase{ }, { description: "Fragmented", - mtu: 1280, - gso: &stack.GSO{}, + mtu: header.IPv6MinimumMTU, + gso: nil, + transHdrLen: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv6MinimumMTU + 1, + gso: nil, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 2000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, @@ -2255,20 +2589,18 @@ var fragmentationTests = []fragmentationTestCase{ { description: "No fragmentation with big header", mtu: 2000, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 100, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1100, more: false}, }, }, { - description: "Fragmented with gso nil", - mtu: 1280, - gso: nil, + description: "Fragmented with gso none", + mtu: header.IPv6MinimumMTU, + gso: &stack.GSO{Type: stack.GSONone}, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1400, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, @@ -2277,31 +2609,18 @@ var fragmentationTests = []fragmentationTestCase{ }, { description: "Fragmented with big header", - mtu: 1280, - gso: &stack.GSO{}, + mtu: header.IPv6MinimumMTU, + gso: nil, transHdrLen: 100, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1200, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, {offset: 154, payloadSize: 76, more: false}, }, }, - { - description: "Fragmented with big header and prependable bytes", - mtu: 1280, - gso: &stack.GSO{}, - transHdrLen: 20, - extraHdrLen: header.IPv6MinimumSize + 66, - payloadSize: 1500, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 296, more: false}, - }, - }, } -func TestFragmentation(t *testing.T) { +func TestFragmentationWritePacket(t *testing.T) { const ( ttl = 42 tos = stack.DefaultTOS @@ -2310,7 +2629,7 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt.Clone() ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) @@ -2331,10 +2650,8 @@ func TestFragmentation(t *testing.T) { if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - if len(ep.WrittenPackets) > 0 { - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) } }) } @@ -2368,7 +2685,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -2378,7 +2695,7 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { @@ -2467,8 +2784,8 @@ func TestFragmentationErrors(t *testing.T) { wantError: tcpip.ErrAborted, }, { - description: "Error on packet with MTU smaller than transport header", - mtu: 1280, + description: "Error when MTU is smaller than transport header", + mtu: header.IPv6MinimumMTU, transHdrLen: 1500, payloadSize: 500, allowPackets: 0, @@ -2476,11 +2793,21 @@ func TestFragmentationErrors(t *testing.T) { mockError: nil, wantError: tcpip.ErrMessageTooLong, }, + { + description: "Error when MTU is smaller than IPv6 minimum MTU", + mtu: header.IPv6MinimumMTU - 1, + transHdrLen: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, + }, } for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 4d3acab96..261705575 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -361,6 +361,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * return tcpip.ErrInvalidEndpointState } + a.mu.Lock() + defer a.mu.Unlock() return a.removePermanentEndpointLocked(addrState) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index cf042309e..380688038 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -178,7 +178,7 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 6f73a0ce4..c9b13cd0e 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { return addr, nil, nil @@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.linkAddr, entry.done, tcpip.ErrWouldBlock @@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 33806340e..d2e37f38d 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -49,8 +49,8 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { - time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 4df288798..eebf43a1f 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -68,7 +68,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -84,7 +84,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -111,6 +111,10 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // provided, it will be notified when address resolution is complete (success // or not). // +// If specified, the local address must be an address local to the interface the +// neighbor cache belongs to. The local address is the source address of a +// packet prompting NUD/link address resolution. +// // If address resolution is required, ErrNoLinkAddress and a notification // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). @@ -118,7 +122,6 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ Addr: remoteAddr, - LocalAddr: localAddr, LinkAddr: linkAddr, State: Static, UpdatedAt: time.Now(), @@ -126,13 +129,13 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA return e, nil, nil } - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() switch s := entry.neigh.State; s { case Stale: - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) fallthrough case Reachable, Static, Delay, Probe: // As per RFC 4861 section 7.3.3: @@ -152,7 +155,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock case Failed: return entry.neigh, nil, tcpip.ErrNoLinkAddress @@ -207,7 +210,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd } else { // Static entry found with the same address but different link address. entry.neigh.LinkAddr = linkAddr - entry.dispatchChangeEventLocked(entry.neigh.State) + entry.dispatchChangeEventLocked() entry.mu.Unlock() return } @@ -220,8 +223,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) - n.cache[addr] = entry + n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } // removeEntryLocked removes the specified entry from the neighbor cache. @@ -292,8 +294,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. -func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) +func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index fcd54ed83..d81f00848 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -128,9 +128,8 @@ func newTestEntryStore() *testEntryStore { linkAddr := toLinkAddress(i) store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LocalAddr: testEntryLocalAddr, - LinkAddr: linkAddr, + Addr: addr, + LinkAddr: linkAddr, } } return store @@ -195,10 +194,10 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(addr) + r.fakeRequest(targetAddr) }) // Execute post address resolution action, if available. @@ -294,9 +293,8 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -305,15 +303,19 @@ func TestNeighborCacheEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -324,8 +326,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -354,9 +356,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -365,15 +367,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -391,9 +397,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -404,8 +412,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -452,8 +460,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -470,23 +478,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, testEntryEventInfo{ EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }) c.nudDisp.mu.Lock() @@ -508,10 +522,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -564,24 +577,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -600,9 +616,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -640,9 +658,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -682,9 +702,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -703,9 +725,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -740,9 +764,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -760,9 +786,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -800,24 +828,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -836,16 +867,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -861,10 +896,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, }, }, } @@ -896,12 +930,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } clock.Advance(typicalLatency) @@ -913,7 +947,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { id, ok := s.Fetch(false /* block */) if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) } if id != wakerID { t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) @@ -923,15 +957,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -964,12 +1002,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } // Remove the waker before the neighbor cache has the opportunity to send a @@ -991,15 +1029,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1028,10 +1070,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) @@ -1041,9 +1082,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -1058,10 +1101,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, }, }, } @@ -1089,9 +1131,8 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1099,15 +1140,19 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1126,9 +1171,11 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1149,16 +1196,20 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestRemoved, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1185,24 +1236,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1220,9 +1274,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1274,29 +1330,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1312,9 +1372,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { for i := neighborCacheSize; i < store.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) - if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1322,15 +1381,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1342,22 +1401,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1374,10 +1439,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { - Addr: frequentlyUsedEntry.Addr, - LocalAddr: frequentlyUsedEntry.LocalAddr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, + Addr: frequentlyUsedEntry.Addr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, }, } @@ -1387,10 +1451,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1430,9 +1493,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } @@ -1456,10 +1518,9 @@ func TestNeighborCacheConcurrent(t *testing.T) { t.Errorf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1488,37 +1549,36 @@ func TestNeighborCacheReplace(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) } if t.Failed() { t.FailNow() } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1542,37 +1602,35 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Delay, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1601,35 +1659,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) - got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } // Verify that address resolution for an unknown address returns ErrNoLinkAddress before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } maxAttempts := neigh.config().MaxUnicastProbes @@ -1659,13 +1716,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } } @@ -1683,18 +1740,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) { delay: typicalLatency, } - got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ - Addr: testEntryBroadcastAddr, - LocalAddr: testEntryLocalAddr, - LinkAddr: testEntryBroadcastLinkAddr, - State: Static, + Addr: testEntryBroadcastAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1719,9 +1775,9 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh != nil { <-doneCh diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index be61a21af..bd80f95bd 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -27,7 +27,6 @@ import ( // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { Addr tcpip.Address - LocalAddr tcpip.Address LinkAddr tcpip.LinkAddress State NeighborState UpdatedAt time.Time @@ -106,35 +105,35 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { return &neighborEntry{ nic: nic, linkRes: linkRes, nudState: nudState, neigh: NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - State: Unknown, + Addr: remoteAddr, + State: Unknown, }, } } -// newStaticNeighborEntry creates a neighbor cache entry starting at the Static -// state. The entry can only transition out of Static by directly calling -// `setStateLocked`. +// newStaticNeighborEntry creates a neighbor cache entry starting at the +// Static state. The entry can only transition out of Static by directly +// calling `setStateLocked`. func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + entry := NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + } if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) } return &neighborEntry{ nic: nic, nudState: state, - neigh: NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - }, + neigh: entry, } } @@ -165,17 +164,17 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. -func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborAdded(e.nic.id, e.neigh) } } // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. -func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborChanged(e.nic.id, e.neigh) } } @@ -183,7 +182,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { // has been removed. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } @@ -206,63 +205,19 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { switch next { case Incomplete: - var retryCounter uint32 - var sendMulticastProbe func() - - sendMulticastProbe = func() { - if retryCounter == config.MaxMulticastProbes { - // "If no Neighbor Advertisement is received after - // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. - // The sender MUST return ICMP destination unreachable indications with - // code 3 (Address Unreachable) for each packet queued awaiting address - // resolution." - RFC 4861 section 7.2.2 - // - // There is no need to send an ICMP destination unreachable indication - // since the failure to resolve the address is expected to only occur - // on this node. Thus, redirecting traffic is currently not supported. - // - // "If the error occurs on a node other than the node originating the - // packet, an ICMP error message is generated. If the error occurs on - // the originating node, an implementation is not required to actually - // create and send an ICMP error packet to the source, as long as the - // upper-layer sender is notified through an appropriate mechanism - // (e.g. return value from a procedure call). Note, however, that an - // implementation may find it convenient in some cases to return errors - // to the sender by taking the offending packet, generating an ICMP - // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { - // There is no need to log the error here; the NUD implementation may - // assume a working link. A valid link should be the responsibility of - // the NIC/stack.LinkEndpoint. - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(config.RetransmitTimer) - } - - sendMulticastProbe() + panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Probe) e.setStateLocked(Probe) + e.dispatchChangeEventLocked() }) e.job.Schedule(config.DelayFirstProbeTime) @@ -277,19 +232,13 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - if retryCounter == config.MaxUnicastProbes { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -315,15 +264,72 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. -func (e *neighborEntry) handlePacketQueuedLocked() { +func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Unknown: - e.dispatchAddEventLocked(Incomplete) - e.setStateLocked(Incomplete) + e.neigh.State = Incomplete + e.neigh.UpdatedAt = time.Now() + + e.dispatchAddEventLocked() + + config := e.nudState.Config() + + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + // As per RFC 4861 section 7.2.2: + // + // If the source address of the packet prompting the solicitation is the + // same as one of the addresses assigned to the outgoing interface, that + // address SHOULD be placed in the IP Source Address of the outgoing + // solicitation. + // + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendMulticastProbe() case Stale: - e.dispatchChangeEventLocked(Delay) e.setStateLocked(Delay) + e.dispatchChangeEventLocked() case Incomplete, Reachable, Delay, Probe, Static, Failed: // Do nothing @@ -345,21 +351,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { switch e.neigh.State { case Unknown, Incomplete, Failed: e.neigh.LinkAddr = remoteLinkAddr - e.dispatchAddEventLocked(Stale) e.setStateLocked(Stale) e.notifyWakersLocked() + e.dispatchAddEventLocked() case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } case Stale: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) + e.dispatchChangeEventLocked() } case Static: @@ -393,12 +399,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.neigh.LinkAddr = linkAddr if flags.Solicited { - e.dispatchChangeEventLocked(Reachable) e.setStateLocked(Reachable) } else { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) } + e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter e.notifyWakersLocked() @@ -411,8 +416,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } break } @@ -421,23 +426,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if !flags.Solicited { if e.neigh.State != Stale { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } else { // Notify the LinkAddr change, even though NUD state hasn't changed. - e.dispatchChangeEventLocked(e.neigh.State) + e.dispatchChangeEventLocked() } break } } if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - } + wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) e.notifyWakersLocked() + if !wasReachable { + e.dispatchChangeEventLocked() + } } if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { @@ -475,11 +481,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - // Set state to Reachable again to refresh timers. - } + wasReachable := e.neigh.State == Reachable + // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) + if !wasReachable { + e.dispatchChangeEventLocked() + } case Unknown, Incomplete, Failed, Static: // Do nothing diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 3ee2a3b31..e8e0e571b 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -52,19 +52,16 @@ const ( // predict the time that an event will be dispatched. func eventDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), } } // eventDiffOptsWithSort is like eventDiffOpts but also includes an option to // sort slices of events for cases where ordering must be ignored. func eventDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 + })) } // The following unit tests exercise every state transition and verify its @@ -125,14 +122,11 @@ func (t testEntryEventType) String() string { type testEntryEventInfo struct { EventType testEntryEventType NICID tcpip.NICID - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Entry NeighborEntry } func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) } // testNUDDispatcher implements NUDDispatcher to validate the dispatching of @@ -150,36 +144,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { d.events = append(d.events, e) } -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestAdded, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestChanged, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestRemoved, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } @@ -202,9 +187,9 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { p := entryTestProbeInfo{ - RemoteAddress: addr, + RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, LocalAddress: localAddr, } @@ -245,7 +230,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ @@ -326,7 +311,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -350,9 +335,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } { @@ -388,9 +375,11 @@ func TestEntryUnknownToStale(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -406,7 +395,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -468,16 +457,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -498,7 +491,7 @@ func TestEntryIncompleteToReachable(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -530,16 +523,20 @@ func TestEntryIncompleteToReachable(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -563,7 +560,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { defer s.Done() e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got := e.wakers; got != nil { t.Errorf("got e.wakers = %v, want = nil", got) } @@ -605,16 +602,20 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -629,7 +630,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -663,16 +664,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -687,7 +692,7 @@ func TestEntryIncompleteToStale(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -719,16 +724,20 @@ func TestEntryIncompleteToStale(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -744,7 +753,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -783,16 +792,20 @@ func TestEntryIncompleteToFailed(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -822,7 +835,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -866,16 +879,20 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -896,7 +913,7 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -932,16 +949,20 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -961,7 +982,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -992,23 +1013,29 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1029,7 +1056,7 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -1062,23 +1089,29 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1099,7 +1132,7 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -1136,23 +1169,29 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1173,7 +1212,7 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -1210,23 +1249,29 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1247,7 +1292,7 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1283,16 +1328,20 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1307,7 +1356,7 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1347,23 +1396,29 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1378,7 +1433,7 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1418,23 +1473,29 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1449,7 +1510,7 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1489,23 +1550,29 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1520,7 +1587,7 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1556,23 +1623,29 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1587,7 +1660,7 @@ func TestEntryStaleToDelay(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1596,7 +1669,7 @@ func TestEntryStaleToDelay(t *testing.T) { if got, want := e.neigh.State, Stale; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -1620,23 +1693,29 @@ func TestEntryStaleToDelay(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1656,13 +1735,13 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -1692,37 +1771,47 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1743,13 +1832,13 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -1786,37 +1875,47 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1837,13 +1936,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if e.neigh.State != Delay { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } @@ -1880,37 +1979,47 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1925,13 +2034,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -1966,23 +2075,29 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1997,13 +2112,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -2031,30 +2146,38 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2069,13 +2192,13 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -2107,30 +2230,38 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2145,13 +2276,13 @@ func TestEntryDelayToProbe(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -2170,7 +2301,6 @@ func TestEntryDelayToProbe(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2184,30 +2314,38 @@ func TestEntryDelayToProbe(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2228,13 +2366,13 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2250,7 +2388,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2274,37 +2411,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2325,13 +2472,13 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2347,7 +2494,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2375,37 +2521,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2426,13 +2582,13 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2448,7 +2604,6 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2479,30 +2634,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2529,7 +2692,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2539,7 +2702,6 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2572,37 +2734,47 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2622,13 +2794,13 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2644,7 +2816,6 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2677,44 +2848,56 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2734,13 +2917,13 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2756,7 +2939,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2786,44 +2968,56 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2843,13 +3037,13 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2865,7 +3059,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2895,44 +3088,56 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2946,87 +3151,116 @@ func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 c.MaxUnicastProbes = 3 + c.DelayFirstProbeTime = c.RetransmitTimer e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + { + wantProbes := []entryTestProbeInfo{ + // Caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.Advance(waitFor) + // Observe each probe sent while in the Probe state. + for i := uint32(0); i < c.MaxUnicastProbes; i++ { + clock.Advance(c.RetransmitTimer) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + } - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.mu.Unlock() } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + + // Wait for the last probe to expire, causing a transition to Failed. + clock.Advance(c.RetransmitTimer) + e.mu.Lock() + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -3034,12 +3268,6 @@ func TestEntryProbeToFailed(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryFailedGetsDeleted(t *testing.T) { @@ -3054,13 +3282,13 @@ func TestEntryFailedGetsDeleted(t *testing.T) { } e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime @@ -3077,17 +3305,14 @@ func TestEntryFailedGetsDeleted(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -3101,37 +3326,47 @@ func TestEntryFailedGetsDeleted(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8828cc5fe..17f2e6b46 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -274,6 +273,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb return n.writePacket(r, gso, protocol, pkt) } +// WritePacketToRemote implements NetworkInterface. +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + r := Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return n.writePacket(&r, gso, protocol, pkt) +} + func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -679,14 +687,17 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease + // the TTL field for ipv4/ipv6. // pkt may have set its header and may not have enough headroom for // link-layer header for the other link to prepend. Here we create a new // packet to forward. fwdPkt := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + // We need to do a deep copy of the IP packet because WritePacket (and + // friends) take ownership of the packet buffer, but we do not own it. + Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), }) // TODO(b/143425874) Decrease the TTL field in forwarded packets. diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 97a96af62..4af04846f 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -169,7 +169,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index e1ec15487..ab629b3a4 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -129,7 +129,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborAdded(tcpip.NICID, NeighborEntry) // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) // neighbor table changes state and/or link address. @@ -138,7 +138,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborChanged(tcpip.NICID, NeighborEntry) // OnNeighborRemoved will be called when an entry is removed from a NIC's // (with ID nicID) neighbor table. @@ -147,7 +147,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborRemoved(tcpip.NICID, NeighborEntry) } // ReachabilityConfirmationFlags describes the flags used within a reachability @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) + HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 105583c49..7f54a6de8 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -311,11 +311,25 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { } // PayloadSince returns packet payload starting from and including a particular -// header. This method isn't optimized and should be used in test only. +// header. +// +// The returned View is owned by the caller - its backing buffer is separate +// from the packet header's underlying packet buffer. func PayloadSince(h PacketHeader) buffer.View { - var v buffer.View + size := h.pk.Data.Size() + for _, hinfo := range h.pk.headers[h.typ:] { + size += len(hinfo.buf) + } + + v := make(buffer.View, 0, size) + for _, hinfo := range h.pk.headers[h.typ:] { v = append(v, hinfo.buf...) } - return append(v, h.pk.Data.ToView()...) + + for _, view := range h.pk.Data.Views() { + v = append(v, view...) + } + + return v } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index defb9129b..203f3b51f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -490,6 +490,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // WritePacketToRemote writes the packet to the given remote link address. + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -764,13 +767,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts - // the request on the local network if remoteLinkAddr is the zero value. The - // request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the link address of the target + // address. The request is broadcasted on the local network if a remote link + // address is not provided. // - // A valid response will cause the discovery protocol's network - // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error + // The request is sent from the passed network interface. If the interface + // local address is unspecified, any interface local address may be used. + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3a07577c8..e8f1c110e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -830,6 +830,20 @@ func (s *Stack) AddRoute(route tcpip.Route) { s.routeTable = append(s.routeTable, route) } +// RemoveRoutes removes matching routes from the route table. +func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + var filteredRoutes []tcpip.Route + for _, route := range s.routeTable { + if !match(route) { + filteredRoutes = append(filteredRoutes, route) + } + } + s.routeTable = filteredRoutes +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] @@ -1323,7 +1337,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) } // Neighbors returns all IP to MAC address associations. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e75f58c64..4eed4ced4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3672,3 +3672,87 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) } } + +// TestAddRoute tests Stack.AddRoute +func TestAddRoute(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + subnet1, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + + expected := []tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + } + + // Initialize the route table with one route. + s.SetRouteTable([]tcpip.Route{expected[0]}) + + // Add another route. + s.AddRoute(expected[1]) + + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +// TestRemoveRoutes tests Stack.RemoveRoutes +func TestRemoveRoutes(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + addressToRemove := tcpip.Address("\x01") + subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet3, err := tcpip.NewSubnet("\x02", "\x02") + if err != nil { + t.Fatal(err) + } + + // Initialize the route table with three routes. + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + {Destination: subnet3, Gateway: "\x00", NIC: 1}, + }) + + // Remove routes with the specific address. + s.RemoveRoutes(func(r tcpip.Route) bool { + return r.Destination.ID() == addressToRemove + }) + + expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 62ab6d92f..6b8071467 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -28,7 +28,7 @@ import ( const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 + fakeTransHeaderLen int = 3 ) // fakeTransportEndpoint is a transport-layer protocol endpoint. It counts diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c42bb0991..3ab2b7654 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -111,6 +111,7 @@ var ( ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} ErrNotPermitted = &Error{msg: "operation not permitted"} ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} + ErrMalformedHeader = &Error{msg: "header is malformed"} ) var messageToError map[string]*Error @@ -159,6 +160,7 @@ func StringToError(s string) *Error { ErrBroadcastDisabled, ErrNotPermitted, ErrAddressFamilyNotSupported, + ErrMalformedHeader, } messageToError = make(map[string]*Error) @@ -354,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool { return s.Prefix() <= 30 && s.Broadcast() == address } -// Equal returns true if s equals o. -// -// Needed to use cmp.Equal on Subnet as its fields are unexported. +// Equal returns true if this Subnet is equal to the given Subnet. func (s Subnet) Equal(o Subnet) bool { + // If this changes, update Route.Equal accordingly. return s == o } @@ -761,6 +762,10 @@ const ( // endpoint that all packets being written have an IP header and the // endpoint should not attach an IP header. IPHdrIncludedOption + + // AcceptConnOption is used by GetSockOptBool to indicate if the + // socket is a listening socket. + AcceptConnOption ) // SockOptInt represents socket options which values have the int type. @@ -1254,6 +1259,12 @@ func (r Route) String() string { return out.String() } +// Equal returns true if the given Route is equal to this Route. +func (r Route) Equal(to Route) bool { + // NOTE: This relies on the fact that r.Destination == to.Destination + return r == to +} + // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 @@ -1494,6 +1505,15 @@ type IPStats struct { // IPTablesOutputDropped is the total number of IP packets dropped in // the Output chain. IPTablesOutputDropped *StatCounter + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived *StatCounter + + // OptionRRReceived is the number of Record Route options seen. + OptionRRReceived *StatCounter + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived *StatCounter } // TCPStats collects TCP-specific stats. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index a4f141253..34aab32d0 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -16,6 +16,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index ffd38ee1a..0dcef7b04 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -21,6 +21,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -178,19 +179,19 @@ func TestForwarding(t *testing.T) { routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired) - routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr) + routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr) - if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil { + if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) } - if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil { + if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) } - if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index bf3a6f6ee..6ddcda70c 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -126,12 +127,12 @@ func TestPing(t *testing.T) { host1Stack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr) - if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 4f2ca7f54..f1028823b 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -80,6 +80,7 @@ func TestPingMulticastBroadcast(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: dst, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -250,6 +251,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: dst, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 41eb0ca44..a17234946 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -378,7 +378,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil default: diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 072601d2d..31831a6d8 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -389,7 +389,12 @@ func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrNotSupported + switch opt { + case tcpip.AcceptConnOption: + return false, nil + default: + return false, tcpip.ErrNotSupported + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e37c00523..79f688129 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -601,7 +601,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil case tcpip.IPHdrIncludedOption: diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 33bfb56cd..7d97cbdc7 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -37,57 +37,57 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { } // beforeSave is invoked by stateify. -func (ep *endpoint) beforeSave() { +func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming // packets. - ep.rcvMu.Lock() + e.rcvMu.Lock() } // saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 + e.rcvBufSizeMax = 0 // Release the lock acquired in beforeSave() so regular endpoint closing // logic can proceed after save. - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return max } // loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max } // afterLoad is invoked by stateify. -func (ep *endpoint) afterLoad() { - stack.StackFromEnv.RegisterRestoredEndpoint(ep) +func (e *endpoint) afterLoad() { + stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. -func (ep *endpoint) Resume(s *stack.Stack) { - ep.stack = s +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s // If the endpoint is connected, re-connect. - if ep.connected { + if e.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false) if err != nil { panic(err) } } // If the endpoint is bound, re-bind. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { + if e.bound { + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if ep.associated { - if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + if e.associated { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b706438bd..6b3238d6b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -425,20 +425,17 @@ func (e *endpoint) notifyAborted() { // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { defer ctx.synRcvdCount.dec() - defer func() { - e.mu.Lock() - e.decSynRcvdCount() - e.mu.Unlock() - }() defer s.decRef() n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() + e.decSynRcvdCount() return } ctx.removePendingEndpoint(n) + e.decSynRcvdCount() n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() @@ -456,7 +453,9 @@ func (e *endpoint) incSynRcvdCount() bool { } func (e *endpoint) decSynRcvdCount() { + e.mu.Lock() e.synRcvdCount-- + e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3bcd3923a..c826942e9 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1999,6 +1999,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.MulticastLoopOption: return true, nil + case tcpip.AcceptConnOption: + e.LockUser() + defer e.UnlockUser() + + return e.EndpointState() == StateListen, nil + default: return false, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go index 7ef2df377..833a7b470 100644 --- a/pkg/tcpip/transport/tcp/sack_scoreboard.go +++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go @@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool { return found } -// Dump prints the state of the scoreboard structure. +// String returns human-readable state of the scoreboard structure. func (s *SACKScoreboard) String() string { var str strings.Builder str.WriteString("SACKScoreboard: {") diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a7149efd0..5f05608e2 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5131,6 +5131,7 @@ func TestKeepalive(t *testing.T) { } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendPacket(nil, &context.Headers{ @@ -5175,6 +5176,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki } func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendV6Packet(nil, &context.Headers{ @@ -5238,13 +5240,14 @@ func TestListenBacklogFull(t *testing.T) { // Test acceptance. // Start listening. - listenBacklog := 2 + listenBacklog := 10 if err := c.EP.Listen(listenBacklog); err != nil { t.Fatalf("Listen failed: %s", err) } - for i := 0; i < listenBacklog; i++ { - executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) + lastPortOffset := uint16(0) + for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } time.Sleep(50 * time.Millisecond) @@ -5252,7 +5255,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 2, + SrcPort: context.TestPort + uint16(lastPortOffset), DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(789), @@ -5293,7 +5296,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { @@ -6722,6 +6725,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + // drain any older notifications from the notification channel before attempting + // 2nd connection. + select { + case <-ch: + default: + } + // Send a SYN request w/ sequence number higher than // the highest sequence number sent. iss = seqnum.Value(792) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 4d7847142..79646fefe 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -373,6 +373,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, const icmpv4VariableHeaderOffset = 4 copy(icmp[icmpv4VariableHeaderOffset:], p1) copy(icmp[header.ICMPv4PayloadOffset:], p2) + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) // Inject packet. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d57ed5d79..cdb5127ab 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -895,6 +895,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil + case tcpip.AcceptConnOption: + return false, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -1366,6 +1369,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.rcvMu.Unlock() } + e.lastErrorMu.Lock() + hasError := e.lastError != nil + e.lastErrorMu.Unlock() + if hasError { + result |= waiter.EventErr + } return result } @@ -1465,14 +1474,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { e.mu.RLock() - defer e.mu.RUnlock() - if e.state == StateConnected { e.lastErrorMu.Lock() - defer e.lastErrorMu.Unlock() - e.lastError = tcpip.ErrConnectionRefused + e.lastErrorMu.Unlock() + e.mu.RUnlock() + + e.waiterQueue.Notify(waiter.EventErr) + return } + e.mu.RUnlock() } } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b4604ba35..fb7738dda 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1452,6 +1452,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1791,7 +1795,6 @@ func TestV4UnknownDestination(t *testing.T) { // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength - } // In the case of large payloads the IP packet may be truncated. Update diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 67a950444..08519d986 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -168,7 +168,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { // // +stateify savable type Queue struct { - list waiterList `state:"zerovalue"` + list waiterList mu sync.RWMutex `state:"nosave"` } |