diff options
Diffstat (limited to 'pkg')
516 files changed, 31173 insertions, 12195 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 4a26e28de..a0654df2f 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -55,6 +55,8 @@ go_library( "sched.go", "seccomp.go", "sem.go", + "sem_amd64.go", + "sem_arm64.go", "shm.go", "signal.go", "signalfd.go", diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index cc3571fad..d1ca56370 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -25,6 +25,8 @@ const ( F_SETLKW = 7 F_SETOWN = 8 F_GETOWN = 9 + F_SETSIG = 10 + F_GETSIG = 11 F_SETOWN_EX = 15 F_GETOWN_EX = 16 F_DUPFD_CLOEXEC = 1024 + 6 diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go index d91c97a64..1070b457c 100644 --- a/pkg/abi/linux/fuse.go +++ b/pkg/abi/linux/fuse.go @@ -19,16 +19,22 @@ import ( "gvisor.dev/gvisor/pkg/marshal/primitive" ) +// FUSEOpcode is a FUSE operation code. +// // +marshal type FUSEOpcode uint32 +// FUSEOpID is a FUSE operation ID. +// // +marshal type FUSEOpID uint64 // FUSE_ROOT_ID is the id of root inode. const FUSE_ROOT_ID = 1 -// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. +// Opcodes for FUSE operations. +// +// Analogous to the opcodes in include/linux/fuse.h. const ( FUSE_LOOKUP FUSEOpcode = 1 FUSE_FORGET = 2 /* no reply */ diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 7df02dd6d..006b5a525 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -121,6 +121,9 @@ const ( // Constants from uapi/linux/fsverity.h. const ( + FS_VERITY_HASH_ALG_SHA256 = 1 + FS_VERITY_HASH_ALG_SHA512 = 2 + FS_IOC_ENABLE_VERITY = 1082156677 FS_IOC_MEASURE_VERITY = 3221513862 ) diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index 487a626cc..0adff8dff 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -32,19 +32,24 @@ const ( SEM_STAT_ANY = 20 ) -const SEM_UNDO = 0x1000 - -// SemidDS is equivalent to struct semid64_ds. +// Information about system-wide sempahore limits and parameters. // -// +marshal -type SemidDS struct { - SemPerm IPCPerm - SemOTime TimeT - SemCTime TimeT - SemNSems uint64 - unused3 uint64 - unused4 uint64 -} +// Source: include/uapi/linux/sem.h +const ( + SEMMNI = 32000 + SEMMSL = 32000 + SEMMNS = SEMMNI * SEMMSL + SEMOPM = 500 + SEMVMX = 32767 + SEMAEM = SEMVMX + + // followings are unused in kernel + SEMUME = SEMOPM + SEMMNU = SEMMNS + SEMMAP = SEMMNS +) + +const SEM_UNDO = 0x1000 // Sembuf is equivalent to struct sembuf. // @@ -54,3 +59,21 @@ type Sembuf struct { SemOp int16 SemFlg int16 } + +// SemInfo is equivalent to struct seminfo. +// +// Source: include/uapi/linux/sem.h +// +// +marshal +type SemInfo struct { + SemMap uint32 + SemMni uint32 + SemMns uint32 + SemMnu uint32 + SemMsl uint32 + SemOpm uint32 + SemUme uint32 + SemUsz uint32 + SemVmx uint32 + SemAem uint32 +} diff --git a/pkg/abi/linux/sem_amd64.go b/pkg/abi/linux/sem_amd64.go new file mode 100644 index 000000000..ab980cb4f --- /dev/null +++ b/pkg/abi/linux/sem_amd64.go @@ -0,0 +1,33 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: arch/x86/include/uapi/asm/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + unused1 uint64 + SemCTime TimeT + unused2 uint64 + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/abi/linux/sem_arm64.go b/pkg/abi/linux/sem_arm64.go new file mode 100644 index 000000000..521468fb1 --- /dev/null +++ b/pkg/abi/linux/sem_arm64.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: include/uapi/asm-generic/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + SemCTime TimeT + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index d156d41e4..556892dc3 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -111,12 +111,12 @@ type SockType int // Socket types, from linux/net.h. const ( SOCK_STREAM SockType = 1 - SOCK_DGRAM = 2 - SOCK_RAW = 3 - SOCK_RDM = 4 - SOCK_SEQPACKET = 5 - SOCK_DCCP = 6 - SOCK_PACKET = 10 + SOCK_DGRAM SockType = 2 + SOCK_RAW SockType = 3 + SOCK_RDM SockType = 4 + SOCK_SEQPACKET SockType = 5 + SOCK_DCCP SockType = 6 + SOCK_PACKET SockType = 10 ) // SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are @@ -448,6 +448,8 @@ type ControlMessageCredentials struct { // A ControlMessageIPPacketInfo is IP_PKTINFO socket control message. // // ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. +// +// +stateify savable type ControlMessageIPPacketInfo struct { NIC int32 LocalAddr InetAddr diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go index 069d0395d..6d1e65cb1 100644 --- a/pkg/bpf/decoder.go +++ b/pkg/bpf/decoder.go @@ -109,7 +109,7 @@ func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error { case B: w.WriteString("1") default: - return fmt.Errorf("Invalid BPF LD size: %v", inst) + return fmt.Errorf("invalid BPF LD size: %v", inst) } return nil } diff --git a/pkg/context/context.go b/pkg/context/context.go index 2613bc752..f3031fc60 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -166,3 +166,27 @@ var bgContext = &logContext{Logger: log.Log()} func Background() Context { return bgContext } + +// WithValue returns a copy of parent in which the value associated with key is +// val. +func WithValue(parent Context, key, val interface{}) Context { + return &withValue{ + Context: parent, + key: key, + val: val, + } +} + +type withValue struct { + Context + key interface{} + val interface{} +} + +// Value implements Context.Value. +func (ctx *withValue) Value(key interface{}) interface{} { + if key == ctx.key { + return ctx.val + } + return ctx.Context.Value(key) +} diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go index a4f4b2c5e..fdfe31417 100644 --- a/pkg/coverage/coverage.go +++ b/pkg/coverage/coverage.go @@ -27,6 +27,7 @@ import ( "io" "sort" "sync/atomic" + "testing" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -34,12 +35,6 @@ import ( "github.com/bazelbuild/rules_go/go/tools/coverdata" ) -// KcovAvailable returns whether the kcov coverage interface is available. It is -// available as long as coverage is enabled for some files. -func KcovAvailable() bool { - return len(coverdata.Cover.Blocks) > 0 -} - // coverageMu must be held while accessing coverdata.Cover. This prevents // concurrent reads/writes from multiple threads collecting coverage data. var coverageMu sync.RWMutex @@ -47,6 +42,22 @@ var coverageMu sync.RWMutex // once ensures that globalData is only initialized once. var once sync.Once +// blockBitLength is the number of bits used to represent coverage block index +// in a synthetic PC (the rest are used to represent the file index). Even +// though a PC has 64 bits, we only use the lower 32 bits because some users +// (e.g., syzkaller) may truncate that address to a 32-bit value. +// +// As of this writing, there are ~1200 files that can be instrumented and at +// most ~1200 blocks per file, so 16 bits is more than enough to represent every +// file and every block. +const blockBitLength = 16 + +// KcovAvailable returns whether the kcov coverage interface is available. It is +// available as long as coverage is enabled for some files. +func KcovAvailable() bool { + return len(coverdata.Cover.Blocks) > 0 +} + var globalData struct { // files is the set of covered files sorted by filename. It is calculated at // startup. @@ -104,14 +115,14 @@ var coveragePool = sync.Pool{ // coverage tools, we reset the global coverage data every time this function is // run. func ConsumeCoverageData(w io.Writer) int { - once.Do(initCoverageData) + InitCoverageData() coverageMu.Lock() defer coverageMu.Unlock() total := 0 var pcBuffer [8]byte - for fileIndex, file := range globalData.files { + for fileNum, file := range globalData.files { counters := coverdata.Cover.Counters[file] for index := 0; index < len(counters); index++ { if atomic.LoadUint32(&counters[index]) == 0 { @@ -119,7 +130,7 @@ func ConsumeCoverageData(w io.Writer) int { } // Non-zero coverage data found; consume it and report as a PC. atomic.StoreUint32(&counters[index], 0) - pc := globalData.syntheticPCs[fileIndex][index] + pc := globalData.syntheticPCs[fileNum][index] usermem.ByteOrder.PutUint64(pcBuffer[:], pc) n, err := w.Write(pcBuffer[:]) if err != nil { @@ -142,31 +153,84 @@ func ConsumeCoverageData(w io.Writer) int { return total } -// initCoverageData initializes globalData. It should only be called once, -// before any kcov data is written. -func initCoverageData() { - // First, order all files. Then calculate synthetic PCs for every block - // (using the well-defined ordering for files as well). - for file := range coverdata.Cover.Blocks { - globalData.files = append(globalData.files, file) +// InitCoverageData initializes globalData. It should be called before any kcov +// data is written. +func InitCoverageData() { + once.Do(func() { + // First, order all files. Then calculate synthetic PCs for every block + // (using the well-defined ordering for files as well). + for file := range coverdata.Cover.Blocks { + globalData.files = append(globalData.files, file) + } + sort.Strings(globalData.files) + + for fileNum, file := range globalData.files { + blocks := coverdata.Cover.Blocks[file] + pcs := make([]uint64, 0, len(blocks)) + for blockNum := range blocks { + pcs = append(pcs, calculateSyntheticPC(fileNum, blockNum)) + } + globalData.syntheticPCs = append(globalData.syntheticPCs, pcs) + } + }) +} + +// Symbolize prints information about the block corresponding to pc. +func Symbolize(out io.Writer, pc uint64) error { + fileNum, blockNum := syntheticPCToIndexes(pc) + file, err := fileFromIndex(fileNum) + if err != nil { + return err + } + block, err := blockFromIndex(file, blockNum) + if err != nil { + return err } - sort.Strings(globalData.files) - - // nextSyntheticPC is the first PC that we generate for a block. - // - // This uses a standard-looking kernel range for simplicity. - // - // FIXME(b/160639712): This is only necessary because syzkaller requires - // addresses in the kernel range. If we can remove this constraint, then we - // should be able to use the actual addresses. - var nextSyntheticPC uint64 = 0xffffffff80000000 - for _, file := range globalData.files { - blocks := coverdata.Cover.Blocks[file] - thisFile := make([]uint64, 0, len(blocks)) - for range blocks { - thisFile = append(thisFile, nextSyntheticPC) - nextSyntheticPC++ // Advance. + writeBlock(out, pc, file, block) + return nil +} + +// WriteAllBlocks prints all information about all blocks along with their +// corresponding synthetic PCs. +func WriteAllBlocks(out io.Writer) { + for fileNum, file := range globalData.files { + for blockNum, block := range coverdata.Cover.Blocks[file] { + writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block) } - globalData.syntheticPCs = append(globalData.syntheticPCs, thisFile) } } + +func calculateSyntheticPC(fileNum int, blockNum int) uint64 { + return (uint64(fileNum) << blockBitLength) + uint64(blockNum) +} + +func syntheticPCToIndexes(pc uint64) (fileNum int, blockNum int) { + return int(pc >> blockBitLength), int(pc & ((1 << blockBitLength) - 1)) +} + +// fileFromIndex returns the name of the file in the sorted list of instrumented files. +func fileFromIndex(i int) (string, error) { + total := len(globalData.files) + if i < 0 || i >= total { + return "", fmt.Errorf("file index out of range: [%d] with length %d", i, total) + } + return globalData.files[i], nil +} + +// blockFromIndex returns the i-th block in the given file. +func blockFromIndex(file string, i int) (testing.CoverBlock, error) { + blocks, ok := coverdata.Cover.Blocks[file] + if !ok { + return testing.CoverBlock{}, fmt.Errorf("instrumented file %s does not exist", file) + } + total := len(blocks) + if i < 0 || i >= total { + return testing.CoverBlock{}, fmt.Errorf("block index out of range: [%d] with length %d", i, total) + } + return blocks[i], nil +} + +func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) { + io.WriteString(out, fmt.Sprintf("%#x\n", pc)) + io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1)) +} diff --git a/pkg/shim/v2/options/BUILD b/pkg/crypto/BUILD index ca212e874..08fa772ca 100644 --- a/pkg/shim/v2/options/BUILD +++ b/pkg/crypto/BUILD @@ -3,9 +3,10 @@ load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( - name = "options", + name = "crypto", srcs = [ - "options.go", + "crypto.go", + "crypto_stdlib.go", ], visibility = ["//:sandbox"], ) diff --git a/pkg/sleep/empty.s b/pkg/crypto/crypto.go index fb37360ac..b26b55d37 100644 --- a/pkg/sleep/empty.s +++ b/pkg/crypto/crypto.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,4 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Empty assembly file so empty func definitions work. +// Package crypto wraps crypto primitives. +package crypto diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go new file mode 100644 index 000000000..74a55a123 --- /dev/null +++ b/pkg/crypto/crypto_stdlib.go @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package crypto + +import ( + "crypto/ecdsa" + "crypto/sha512" + "math/big" +) + +// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the +// public key, pub. Its return value records whether the signature is valid. +func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { + return ecdsa.Verify(pub, hash, r, s) +} + +// SumSha384 returns the SHA384 checksum of the data. +func SumSha384(data []byte) (sum384 [sha512.Size384]byte) { + return sha512.Sum384(data) +} diff --git a/pkg/fdchannel/fdchannel_unsafe.go b/pkg/fdchannel/fdchannel_unsafe.go index 367235be5..b253a8fdd 100644 --- a/pkg/fdchannel/fdchannel_unsafe.go +++ b/pkg/fdchannel/fdchannel_unsafe.go @@ -21,7 +21,6 @@ package fdchannel import ( "fmt" "reflect" - "sync/atomic" "syscall" "unsafe" ) @@ -41,7 +40,7 @@ func NewConnectedSockets() ([2]int, error) { // // Endpoint is not copyable or movable by value. type Endpoint struct { - sockfd int32 // accessed using atomic memory operations + sockfd int32 msghdr syscall.Msghdr cmsg *syscall.Cmsghdr // followed by sizeofInt32 bytes of data } @@ -54,10 +53,10 @@ func (ep *Endpoint) Init(sockfd int) { // sendmsg+recvmsg for a zero-length datagram is slightly faster than // sendmsg+recvmsg for a single byte over a stream socket. cmsgSlice := make([]byte, syscall.CmsgSpace(sizeofInt32)) - cmsgReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&cmsgSlice)) + cmsgReflect := (*reflect.SliceHeader)(unsafe.Pointer(&cmsgSlice)) ep.sockfd = int32(sockfd) - ep.msghdr.Control = (*byte)((unsafe.Pointer)(cmsgReflect.Data)) - ep.cmsg = (*syscall.Cmsghdr)((unsafe.Pointer)(cmsgReflect.Data)) + ep.msghdr.Control = (*byte)(unsafe.Pointer(cmsgReflect.Data)) + ep.cmsg = (*syscall.Cmsghdr)(unsafe.Pointer(cmsgReflect.Data)) // ep.msghdr.Controllen and ep.cmsg.* are mutated by recvmsg(2), so they're // set before calling sendmsg/recvmsg. } @@ -73,12 +72,8 @@ func NewEndpoint(sockfd int) *Endpoint { // Destroy releases resources owned by ep. No other Endpoint methods may be // called after Destroy. func (ep *Endpoint) Destroy() { - // These need not use sync/atomic since there must not be any concurrent - // calls to Endpoint methods. - if ep.sockfd >= 0 { - syscall.Close(int(ep.sockfd)) - ep.sockfd = -1 - } + syscall.Close(int(ep.sockfd)) + ep.sockfd = -1 } // Shutdown causes concurrent and future calls to ep.SendFD(), ep.RecvFD(), and @@ -88,10 +83,7 @@ func (ep *Endpoint) Destroy() { // Shutdown is the only Endpoint method that may be called concurrently with // other methods. func (ep *Endpoint) Shutdown() { - if sockfd := int(atomic.SwapInt32(&ep.sockfd, -1)); sockfd >= 0 { - syscall.Shutdown(sockfd, syscall.SHUT_RDWR) - syscall.Close(sockfd) - } + syscall.Shutdown(int(ep.sockfd), syscall.SHUT_RDWR) } // SendFD sends the open file description represented by the given file @@ -103,7 +95,7 @@ func (ep *Endpoint) SendFD(fd int) error { ep.cmsg.SetLen(cmsgLen) *ep.cmsgData() = int32(fd) ep.msghdr.SetControllen(cmsgLen) - _, _, e := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(atomic.LoadInt32(&ep.sockfd)), uintptr((unsafe.Pointer)(&ep.msghdr)), 0) + _, _, e := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), 0) if e != 0 { return e } @@ -113,7 +105,7 @@ func (ep *Endpoint) SendFD(fd int) error { // RecvFD receives an open file description from the connected Endpoint and // returns a file descriptor representing it, owned by the caller. func (ep *Endpoint) RecvFD() (int, error) { - return ep.recvFD(0) + return ep.recvFD(false) } // RecvFDNonblock receives an open file description from the connected Endpoint @@ -121,13 +113,18 @@ func (ep *Endpoint) RecvFD() (int, error) { // are no pending receivable open file descriptions, RecvFDNonblock returns // (<unspecified>, EAGAIN or EWOULDBLOCK). func (ep *Endpoint) RecvFDNonblock() (int, error) { - return ep.recvFD(syscall.MSG_DONTWAIT) + return ep.recvFD(true) } -func (ep *Endpoint) recvFD(flags uintptr) (int, error) { +func (ep *Endpoint) recvFD(nonblock bool) (int, error) { cmsgLen := syscall.CmsgLen(sizeofInt32) ep.msghdr.SetControllen(cmsgLen) - _, _, e := syscall.Syscall(syscall.SYS_RECVMSG, uintptr(atomic.LoadInt32(&ep.sockfd)), uintptr((unsafe.Pointer)(&ep.msghdr)), flags|syscall.MSG_TRUNC) + var e syscall.Errno + if nonblock { + _, _, e = syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), syscall.MSG_TRUNC|syscall.MSG_DONTWAIT) + } else { + _, _, e = syscall.Syscall(syscall.SYS_RECVMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), syscall.MSG_TRUNC) + } if e != 0 { return -1, e } @@ -142,5 +139,5 @@ func (ep *Endpoint) recvFD(flags uintptr) (int, error) { func (ep *Endpoint) cmsgData() *int32 { // syscall.CmsgLen(0) == syscall.cmsgAlignOf(syscall.SizeofCmsghdr) - return (*int32)((unsafe.Pointer)(uintptr((unsafe.Pointer)(ep.cmsg)) + uintptr(syscall.CmsgLen(0)))) + return (*int32)(unsafe.Pointer(uintptr(unsafe.Pointer(ep.cmsg)) + uintptr(syscall.CmsgLen(0)))) } diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index aa8e4e1f3..cc31d0175 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -11,7 +11,8 @@ go_library( "futex_linux.go", "io.go", "packet_window_allocator.go", - "packet_window_mmap.go", + "packet_window_mmap_amd64.go", + "packet_window_mmap_arm64.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index e7c3a3a0b..2e8452a02 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -40,17 +40,41 @@ func (ep *Endpoint) ctrlInit(opts ...EndpointOption) error { return nil } -type ctrlHandshakeRequest struct{} - -type ctrlHandshakeResponse struct{} - func (ep *Endpoint) ctrlConnect() error { if err := ep.enterFutexWait(); err != nil { return err } - _, err := ep.futexConnect(&ctrlHandshakeRequest{}) - ep.exitFutexWait() - return err + defer ep.exitFutexWait() + + // Write the connection request. + w := ep.NewWriter() + if err := json.NewEncoder(w).Encode(struct{}{}); err != nil { + return fmt.Errorf("error writing connection request: %v", err) + } + *ep.dataLen() = w.Len() + + // Exchange control with the server. + if err := ep.futexSetPeerActive(); err != nil { + return err + } + if err := ep.futexWakePeer(); err != nil { + return err + } + if err := ep.futexWaitUntilActive(); err != nil { + return err + } + + // Read the connection response. + var resp struct{} + respLen := atomic.LoadUint32(ep.dataLen()) + if respLen > ep.dataCap { + return fmt.Errorf("invalid connection response length %d (maximum %d)", respLen, ep.dataCap) + } + if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil { + return fmt.Errorf("error reading connection response: %v", err) + } + + return nil } func (ep *Endpoint) ctrlWaitFirst() error { @@ -59,52 +83,61 @@ func (ep *Endpoint) ctrlWaitFirst() error { } defer ep.exitFutexWait() - // Wait for the handshake request. - if err := ep.futexSwitchFromPeer(); err != nil { + // Wait for the connection request. + if err := ep.futexWaitUntilActive(); err != nil { return err } - // Read the handshake request. + // Read the connection request. reqLen := atomic.LoadUint32(ep.dataLen()) if reqLen > ep.dataCap { - return fmt.Errorf("invalid handshake request length %d (maximum %d)", reqLen, ep.dataCap) + return fmt.Errorf("invalid connection request length %d (maximum %d)", reqLen, ep.dataCap) } - var req ctrlHandshakeRequest + var req struct{} if err := json.NewDecoder(ep.NewReader(reqLen)).Decode(&req); err != nil { - return fmt.Errorf("error reading handshake request: %v", err) + return fmt.Errorf("error reading connection request: %v", err) } - // Write the handshake response. + // Write the connection response. w := ep.NewWriter() - if err := json.NewEncoder(w).Encode(ctrlHandshakeResponse{}); err != nil { - return fmt.Errorf("error writing handshake response: %v", err) + if err := json.NewEncoder(w).Encode(struct{}{}); err != nil { + return fmt.Errorf("error writing connection response: %v", err) } *ep.dataLen() = w.Len() // Return control to the client. raceBecomeInactive() - if err := ep.futexSwitchToPeer(); err != nil { + if err := ep.futexSetPeerActive(); err != nil { + return err + } + if err := ep.futexWakePeer(); err != nil { return err } - // Wait for the first non-handshake message. - return ep.futexSwitchFromPeer() + // Wait for the first non-connection message. + return ep.futexWaitUntilActive() } func (ep *Endpoint) ctrlRoundTrip() error { - if err := ep.futexSwitchToPeer(); err != nil { + if err := ep.enterFutexWait(); err != nil { return err } - if err := ep.enterFutexWait(); err != nil { + defer ep.exitFutexWait() + + if err := ep.futexSetPeerActive(); err != nil { return err } - err := ep.futexSwitchFromPeer() - ep.exitFutexWait() - return err + if err := ep.futexWakePeer(); err != nil { + return err + } + return ep.futexWaitUntilActive() } func (ep *Endpoint) ctrlWakeLast() error { - return ep.futexSwitchToPeer() + if err := ep.futexSetPeerActive(); err != nil { + return err + } + return ep.futexWakePeer() } func (ep *Endpoint) enterFutexWait() error { diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index ac974b232..580bf23a4 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -41,11 +41,11 @@ const ( ) func (ep *Endpoint) connState() *uint32 { - return (*uint32)((unsafe.Pointer)(ep.packet)) + return (*uint32)(unsafe.Pointer(ep.packet)) } func (ep *Endpoint) dataLen() *uint32 { - return (*uint32)((unsafe.Pointer)(ep.packet + 4)) + return (*uint32)(unsafe.Pointer(ep.packet + 4)) } // Data returns the datagram part of ep's packet window as a byte slice. @@ -63,7 +63,7 @@ func (ep *Endpoint) dataLen() *uint32 { // all. func (ep *Endpoint) Data() []byte { var bs []byte - bsReflect := (*reflect.SliceHeader)((unsafe.Pointer)(&bs)) + bsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) bsReflect.Data = ep.packet + PacketHeaderBytes bsReflect.Len = int(ep.dataCap) bsReflect.Cap = int(ep.dataCap) @@ -76,12 +76,12 @@ var ioSync int64 func raceBecomeActive() { if sync.RaceEnabled { - sync.RaceAcquire((unsafe.Pointer)(&ioSync)) + sync.RaceAcquire(unsafe.Pointer(&ioSync)) } } func raceBecomeInactive() { if sync.RaceEnabled { - sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + sync.RaceReleaseMerge(unsafe.Pointer(&ioSync)) } } diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index 168c1ccff..0e559ee16 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -17,7 +17,6 @@ package flipcall import ( - "encoding/json" "fmt" "runtime" "sync/atomic" @@ -26,55 +25,26 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeResponse, error) { - var resp ctrlHandshakeResponse - - // Write the handshake request. - w := ep.NewWriter() - if err := json.NewEncoder(w).Encode(req); err != nil { - return resp, fmt.Errorf("error writing handshake request: %v", err) - } - *ep.dataLen() = w.Len() - - // Exchange control with the server. - if err := ep.futexSwitchToPeer(); err != nil { - return resp, err +func (ep *Endpoint) futexSetPeerActive() error { + if atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { + return nil } - if err := ep.futexSwitchFromPeer(); err != nil { - return resp, err + switch cs := atomic.LoadUint32(ep.connState()); cs { + case csShutdown: + return ShutdownError{} + default: + return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) } - - // Read the handshake response. - respLen := atomic.LoadUint32(ep.dataLen()) - if respLen > ep.dataCap { - return resp, fmt.Errorf("invalid handshake response length %d (maximum %d)", respLen, ep.dataCap) - } - if err := json.NewDecoder(ep.NewReader(respLen)).Decode(&resp); err != nil { - return resp, fmt.Errorf("error reading handshake response: %v", err) - } - - return resp, nil } -func (ep *Endpoint) futexSwitchToPeer() error { - // Update connection state to indicate that the peer should be active. - if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { - switch cs := atomic.LoadUint32(ep.connState()); cs { - case csShutdown: - return ShutdownError{} - default: - return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) - } - } - - // Wake the peer's Endpoint.futexSwitchFromPeer(). +func (ep *Endpoint) futexWakePeer() error { if err := ep.futexWakeConnState(1); err != nil { return fmt.Errorf("failed to FUTEX_WAKE peer Endpoint: %v", err) } return nil } -func (ep *Endpoint) futexSwitchFromPeer() error { +func (ep *Endpoint) futexWaitUntilActive() error { for { switch cs := atomic.LoadUint32(ep.connState()); cs { case ep.activeState: diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap_amd64.go index 869183b11..869183b11 100644 --- a/pkg/flipcall/packet_window_mmap.go +++ b/pkg/flipcall/packet_window_mmap_amd64.go diff --git a/pkg/flipcall/packet_window_mmap_arm64.go b/pkg/flipcall/packet_window_mmap_arm64.go new file mode 100644 index 000000000..b9c9c44f6 --- /dev/null +++ b/pkg/flipcall/packet_window_mmap_arm64.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package flipcall + +import ( + "syscall" +) + +// Return a memory mapping of the pwd in memory that can be shared outside the sandbox. +func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) { + m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) + return m, err +} diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD index 7a82631c5..08832a8ae 100644 --- a/pkg/goid/BUILD +++ b/pkg/goid/BUILD @@ -8,9 +8,8 @@ go_library( "goid.go", "goid_amd64.s", "goid_arm64.s", - "goid_race.go", - "goid_unsafe.go", ], + stateify = False, visibility = ["//visibility:public"], ) @@ -18,7 +17,6 @@ go_test( name = "goid_test", size = "small", srcs = [ - "empty_test.go", "goid_test.go", ], library = ":goid", diff --git a/pkg/goid/goid.go b/pkg/goid/goid.go index 39df30031..17c384cb0 100644 --- a/pkg/goid/goid.go +++ b/pkg/goid/goid.go @@ -12,13 +12,61 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !race +// +build go1.12 +// +build !go1.17 -// Package goid provides access to the ID of the current goroutine in -// race/gotsan builds. +// Check type signatures when updating Go version. + +// Package goid provides the Get function. package goid // Get returns the ID of the current goroutine. func Get() int64 { - panic("unimplemented for non-race builds") + return getg().goid +} + +// Structs from Go runtime. These may change in the future and require +// updating. These structs are currently the same on both AMD64 and ARM64, +// but may diverge in the future. + +type stack struct { + lo uintptr + hi uintptr +} + +type gobuf struct { + sp uintptr + pc uintptr + g uintptr + ctxt uintptr + ret uint64 + lr uintptr + bp uintptr } + +type g struct { + stack stack + stackguard0 uintptr + stackguard1 uintptr + + _panic uintptr + _defer uintptr + m uintptr + sched gobuf + syscallsp uintptr + syscallpc uintptr + stktopsp uintptr + param uintptr + atomicstatus uint32 + stackLock uint32 + goid int64 + + // More fields... + // + // We only use goid and the fields before it are only listed to + // calculate the correct offset. +} + +// Defined in assembly. This can't use go:linkname since runtime.getg() isn't a +// real function, it's a compiler intrinsic. +func getg() *g diff --git a/pkg/goid/goid_test.go b/pkg/goid/goid_test.go index 31970ce79..54be11d63 100644 --- a/pkg/goid/goid_test.go +++ b/pkg/goid/goid_test.go @@ -12,63 +12,70 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build race - package goid import ( "runtime" "sync" "testing" + "time" ) -func TestInitialGoID(t *testing.T) { - const max = 10000 - if id := goid(); id < 0 || id > max { - t.Errorf("got goid = %d, want 0 < goid <= %d", id, max) - } -} +func TestUniquenessAndConsistency(t *testing.T) { + const ( + numGoroutines = 5000 -// TestGoIDSquence verifies that goid returns values which could plausibly be -// goroutine IDs. If this test breaks or becomes flaky, the structs in -// goid_unsafe.go may need to be updated. -func TestGoIDSquence(t *testing.T) { - // Goroutine IDs are cached by each P. - runtime.GOMAXPROCS(1) + // maxID is not an intrinsic property of goroutine IDs; it is only a + // property of how the Go runtime currently assigns them. Future + // changes to the Go runtime may require that maxID be raised, or that + // assertions regarding it be removed entirely. + maxID = numGoroutines + 1000 + ) - // Fill any holes in lower range. - for i := 0; i < 50; i++ { - var wg sync.WaitGroup - wg.Add(1) + var ( + goidsMu sync.Mutex + goids = make(map[int64]struct{}) + checkedWG sync.WaitGroup + exitCh = make(chan struct{}) + ) + for i := 0; i < numGoroutines; i++ { + checkedWG.Add(1) go func() { - wg.Done() - - // Leak the goroutine to prevent the ID from being - // reused. - select {} - }() - wg.Wait() - } - - id := goid() - for i := 0; i < 100; i++ { - var ( - newID int64 - wg sync.WaitGroup - ) - wg.Add(1) - go func() { - newID = goid() - wg.Done() - - // Leak the goroutine to prevent the ID from being - // reused. - select {} + id := Get() + if id > maxID { + t.Errorf("observed unexpectedly large goroutine ID %d", id) + } + goidsMu.Lock() + if _, dup := goids[id]; dup { + t.Errorf("observed duplicate goroutine ID %d", id) + } + goids[id] = struct{}{} + goidsMu.Unlock() + checkedWG.Done() + for { + if curID := Get(); curID != id { + t.Errorf("goroutine ID changed from %d to %d", id, curID) + // Don't spam logs by repeating the check; wait quietly for + // the test to finish. + <-exitCh + return + } + // Check if the test is over. + select { + case <-exitCh: + return + default: + } + // Yield to other goroutines, and possibly migrate to another P. + runtime.Gosched() + } }() - wg.Wait() - if max := id + 100; newID <= id || newID > max { - t.Errorf("unexpected goroutine ID pattern, got goid = %d, want %d < goid <= %d (previous = %d)", newID, id, max, id) - } - id = newID } + // Wait for all goroutines to perform uniqueness checks. + checkedWG.Wait() + // Wait for an additional second to allow goroutines to spin checking for + // ID consistency. + time.Sleep(time.Second) + // Request that all goroutines exit. + close(exitCh) } diff --git a/pkg/goid/goid_unsafe.go b/pkg/goid/goid_unsafe.go deleted file mode 100644 index ded8004dd..000000000 --- a/pkg/goid/goid_unsafe.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package goid - -// Structs from Go runtime. These may change in the future and require -// updating. These structs are currently the same on both AMD64 and ARM64, -// but may diverge in the future. - -type stack struct { - lo uintptr - hi uintptr -} - -type gobuf struct { - sp uintptr - pc uintptr - g uintptr - ctxt uintptr - ret uint64 - lr uintptr - bp uintptr -} - -type g struct { - stack stack - stackguard0 uintptr - stackguard1 uintptr - - _panic uintptr - _defer uintptr - m uintptr - sched gobuf - syscallsp uintptr - syscallpc uintptr - stktopsp uintptr - param uintptr - atomicstatus uint32 - stackLock uint32 - goid int64 - - // More fields... - // - // We only use goid and the fields before it are only listed to - // calculate the correct offset. -} - -func getg() *g - -// goid returns the ID of the current goroutine. -func goid() int64 { - return getg().goid -} diff --git a/pkg/log/json.go b/pkg/log/json.go index bdf9d691e..8c52dcc87 100644 --- a/pkg/log/json.go +++ b/pkg/log/json.go @@ -27,8 +27,8 @@ type jsonLog struct { } // MarshalJSON implements json.Marshaler.MarashalJSON. -func (lv Level) MarshalJSON() ([]byte, error) { - switch lv { +func (l Level) MarshalJSON() ([]byte, error) { + switch l { case Warning: return []byte(`"warning"`), nil case Info: @@ -36,20 +36,20 @@ func (lv Level) MarshalJSON() ([]byte, error) { case Debug: return []byte(`"debug"`), nil default: - return nil, fmt.Errorf("unknown level %v", lv) + return nil, fmt.Errorf("unknown level %v", l) } } // UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON. It can unmarshal // from both string names and integers. -func (lv *Level) UnmarshalJSON(b []byte) error { +func (l *Level) UnmarshalJSON(b []byte) error { switch s := string(b); s { case "0", `"warning"`: - *lv = Warning + *l = Warning case "1", `"info"`: - *lv = Info + *l = Info case "2", `"debug"`: - *lv = Debug + *l = Debug default: return fmt.Errorf("unknown level %q", s) } diff --git a/pkg/log/log.go b/pkg/log/log.go index 37e0605ad..2e3408357 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -356,7 +356,7 @@ func CopyStandardLogTo(l Level) error { case Warning: f = Warningf default: - return fmt.Errorf("Unknown log level %v", l) + return fmt.Errorf("unknown log level %v", l) } stdlog.SetOutput(linewriter.NewWriter(func(p []byte) { diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD index a8fcb2e19..501a9ef21 100644 --- a/pkg/merkletree/BUILD +++ b/pkg/merkletree/BUILD @@ -6,12 +6,18 @@ go_library( name = "merkletree", srcs = ["merkletree.go"], visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) go_test( name = "merkletree_test", srcs = ["merkletree_test.go"], library = ":merkletree", - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index d8227b8bd..aea7dde38 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -18,21 +18,33 @@ package merkletree import ( "bytes" "crypto/sha256" + "crypto/sha512" + "encoding/gob" "fmt" "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) const ( // sha256DigestSize specifies the digest size of a SHA256 hash. sha256DigestSize = 32 + // sha512DigestSize specifies the digest size of a SHA512 hash. + sha512DigestSize = 64 ) // DigestSize returns the size (in bytes) of a digest. -// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). -func DigestSize() int { - return sha256DigestSize +// TODO(b/156980949): Allow config SHA384. +func DigestSize(hashAlgorithm int) int { + switch hashAlgorithm { + case linux.FS_VERITY_HASH_ALG_SHA256: + return sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + return sha512DigestSize + default: + return -1 + } } // Layout defines the scale of a Merkle tree. @@ -51,11 +63,19 @@ type Layout struct { // InitLayout initializes and returns a new Layout object describing the structure // of a tree. dataSize specifies the size of input data in bytes. -func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { +func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) { layout := Layout{ blockSize: usermem.PageSize, - // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). - digestSize: sha256DigestSize, + } + + // TODO(b/156980949): Allow config SHA384. + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + layout.digestSize = sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + layout.digestSize = sha512DigestSize + default: + return Layout{}, fmt.Errorf("unexpected hash algorithms") } // treeStart is the offset (in bytes) of the first level of the tree in @@ -88,7 +108,7 @@ func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { } layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) - return layout + return layout, nil } // hashesPerBlock() returns the number of digests in each block. For example, @@ -128,23 +148,49 @@ func (layout Layout) blockOffset(level int, index int64) int64 { // meatadata. type VerityDescriptor struct { Name string + FileSize int64 Mode uint32 UID uint32 GID uint32 + Children map[string]struct{} RootHash []byte } func (d *VerityDescriptor) String() string { - return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash) + b := new(bytes.Buffer) + e := gob.NewEncoder(b) + e.Encode(d.Children) + return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, Children: %v, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, b.Bytes(), d.RootHash) } // verify generates a hash from d, and compares it with expected. -func (d *VerityDescriptor) verify(expected []byte) error { - h := sha256.Sum256([]byte(d.String())) +func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error { + h, err := hashData([]byte(d.String()), hashAlgorithms) + if err != nil { + return err + } if !bytes.Equal(h[:], expected) { return fmt.Errorf("unexpected root hash") } return nil + +} + +// hashData hashes data and returns the result hash based on the hash +// algorithms. +func hashData(data []byte, hashAlgorithms int) ([]byte, error) { + var digest []byte + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + digestArray := sha256.Sum256(data) + digest = digestArray[:] + case linux.FS_VERITY_HASH_ALG_SHA512: + digestArray := sha512.Sum512(data) + digest = digestArray[:] + default: + return nil, fmt.Errorf("unexpected hash algorithms") + } + return digest, nil } // GenerateParams contains the parameters used to generate a Merkle tree. @@ -161,6 +207,11 @@ type GenerateParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // Children is a map of children names for a directory. It should be + // empty for a regular file. + Children map[string]struct{} + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // TreeReader is a reader for the Merkle tree. TreeReader io.ReaderAt // TreeWriter is a writer for the Merkle tree. @@ -176,7 +227,10 @@ type GenerateParams struct { // Generate returns a hash of a VerityDescriptor, which contains the file // metadata and the hash from file content. func Generate(params *GenerateParams) ([]byte, error) { - layout := InitLayout(params.Size, params.DataAndTreeInSameFile) + layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return nil, err + } numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize @@ -218,10 +272,13 @@ func Generate(params *GenerateParams) ([]byte, error) { return nil, err } // Hash the bytes in buf. - digest := sha256.Sum256(buf) + digest, err := hashData(buf, params.HashAlgorithms) + if err != nil { + return nil, err + } if level == layout.rootLevel() { - root = digest[:] + root = digest } // Write the generated hash to the end of the tree file. @@ -241,13 +298,14 @@ func Generate(params *GenerateParams) ([]byte, error) { } descriptor := VerityDescriptor{ Name: params.Name, + FileSize: params.Size, Mode: params.Mode, UID: params.UID, GID: params.GID, + Children: params.Children, RootHash: root, } - ret := sha256.Sum256([]byte(descriptor.String())) - return ret[:], nil + return hashData([]byte(descriptor.String()), params.HashAlgorithms) } // VerifyParams contains the params used to verify a portion of a file against @@ -269,6 +327,11 @@ type VerifyParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // Children is a map of children names for a directory. It should be + // empty for a regular file. + Children map[string]struct{} + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // ReadOffset is the offset of the data range to be verified. ReadOffset int64 // ReadSize is the size of the data range to be verified. @@ -287,18 +350,24 @@ type VerifyParams struct { // For verifyMetadata, params.data is not needed. It only accesses params.tree // for the raw root hash. func verifyMetadata(params *VerifyParams, layout *Layout) error { - root := make([]byte, layout.digestSize) - if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil { - return fmt.Errorf("failed to read root hash: %w", err) + var root []byte + // Only read the root hash if we expect that the Merkle tree file is non-empty. + if params.Size != 0 { + root = make([]byte, layout.digestSize) + if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil { + return fmt.Errorf("failed to read root hash: %w", err) + } } descriptor := VerityDescriptor{ Name: params.Name, + FileSize: params.Size, Mode: params.Mode, UID: params.UID, GID: params.GID, + Children: params.Children, RootHash: root, } - return descriptor.verify(params.Expected) + return descriptor.verify(params.Expected, params.HashAlgorithms) } // Verify verifies the content read from data with offset. The content is @@ -313,7 +382,10 @@ func Verify(params *VerifyParams) (int64, error) { if params.ReadSize < 0 { return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize) } - layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile) + layout, err := InitLayout(int64(params.Size), params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return 0, err + } if params.ReadSize == 0 { return 0, verifyMetadata(params, &layout) } @@ -349,12 +421,14 @@ func Verify(params *VerifyParams) (int64, error) { } } descriptor := VerityDescriptor{ - Name: params.Name, - Mode: params.Mode, - UID: params.UID, - GID: params.GID, + Name: params.Name, + FileSize: params.Size, + Mode: params.Mode, + UID: params.UID, + GID: params.GID, + Children: params.Children, } - if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil { + if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil { return 0, err } @@ -395,7 +469,7 @@ func Verify(params *VerifyParams) (int64, error) { // fails if the calculated hash from block is different from any level of // hashes stored in tree. And the final root hash is compared with // expected. -func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error { +func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error { if len(dataBlock) != int(layout.blockSize) { return fmt.Errorf("incorrect block size") } @@ -406,8 +480,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, for level := 0; level < layout.numLevels(); level++ { // Calculate hash. if level == 0 { - digestArray := sha256.Sum256(dataBlock) - digest = digestArray[:] + h, err := hashData(dataBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } else { // Read a block in previous level that contains the // hash we just generated, and generate a next level @@ -415,8 +492,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil { return err } - digestArray := sha256.Sum256(treeBlock) - digest = digestArray[:] + h, err := hashData(treeBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } // Read the digest for the current block and store in @@ -434,5 +514,5 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, // Verification for the tree succeeded. Now hash the descriptor with // the root hash and compare it with expected. descriptor.RootHash = digest - return descriptor.verify(expected) + return descriptor.verify(expected, hashAlgorithms) } diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index e1350ebda..66ddf09e6 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -16,60 +16,134 @@ package merkletree import ( "bytes" + "errors" "fmt" "io" "math/rand" "testing" "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) func TestLayout(t *testing.T) { testCases := []struct { + name string dataSize int64 + hashAlgorithms int dataAndTreeInSameFile bool + expectedDigestSize int64 expectedLevelOffset []int64 }{ { + name: "SmallSizeSHA256SeparateFile", dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0}, }, { + name: "SmallSizeSHA512SeparateFile", dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0}, + }, + { + name: "SmallSizeSHA256SameFile", + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{usermem.PageSize}, }, { + name: "SmallSizeSHA512SameFile", + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{usermem.PageSize}, + }, + { + name: "MiddleSizeSHA256SeparateFile", dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, }, { + name: "MiddleSizeSHA512SeparateFile", + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize}, + }, + { + name: "MiddleSizeSHA256SameFile", dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, }, { + name: "MiddleSizeSHA512SameFile", + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize}, + }, + { + name: "LargeSizeSHA256SeparateFile", dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, }, { + name: "LargeSizeSHA512SeparateFile", dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize}, + }, + { + name: "LargeSizeSHA256SameFile", + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, }, + { + name: "LargeSizeSHA512SameFile", + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize}, + }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { - l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) + t.Run(tc.name, func(t *testing.T) { + l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile) + if err != nil { + t.Fatalf("Failed to InitLayout: %v", err) + } if l.blockSize != int64(usermem.PageSize) { t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } - if l.digestSize != sha256DigestSize { + if l.digestSize != tc.expectedDigestSize { t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) } if l.numLevels() != len(tc.expectedLevelOffset) { @@ -118,404 +192,916 @@ func TestGenerate(t *testing.T) { // The input data has size dataSize. It starts with the data in startWith, // and all other bytes are zeroes. testCases := []struct { - data []byte - expectedHash []byte + name string + data []byte + hashAlgorithms int + dataAndTreeInSameFile bool + expectedHash []byte }{ { - data: bytes.Repeat([]byte{0}, usermem.PageSize), - expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, + name: "OnePageZeroesSHA256SeparateFile", + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + expectedHash: []byte{42, 197, 191, 52, 206, 122, 93, 34, 198, 125, 100, 154, 171, 177, 94, 14, 49, 40, 76, 157, 122, 58, 78, 6, 163, 248, 30, 238, 16, 190, 173, 175}, + }, + { + name: "OnePageZeroesSHA256SameFile", + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedHash: []byte{42, 197, 191, 52, 206, 122, 93, 34, 198, 125, 100, 154, 171, 177, 94, 14, 49, 40, 76, 157, 122, 58, 78, 6, 163, 248, 30, 238, 16, 190, 173, 175}, + }, + { + name: "OnePageZeroesSHA512SeparateFile", + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedHash: []byte{87, 131, 150, 74, 0, 218, 117, 114, 34, 23, 212, 16, 122, 97, 124, 172, 41, 46, 107, 150, 33, 46, 56, 39, 5, 246, 215, 187, 140, 83, 35, 63, 111, 74, 155, 241, 161, 214, 92, 141, 232, 125, 99, 71, 168, 102, 82, 20, 229, 249, 248, 28, 29, 238, 199, 223, 173, 180, 179, 46, 241, 240, 237, 74}, + }, + { + name: "OnePageZeroesSHA512SameFile", + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedHash: []byte{87, 131, 150, 74, 0, 218, 117, 114, 34, 23, 212, 16, 122, 97, 124, 172, 41, 46, 107, 150, 33, 46, 56, 39, 5, 246, 215, 187, 140, 83, 35, 63, 111, 74, 155, 241, 161, 214, 92, 141, 232, 125, 99, 71, 168, 102, 82, 20, 229, 249, 248, 28, 29, 238, 199, 223, 173, 180, 179, 46, 241, 240, 237, 74}, + }, + { + name: "MultiplePageZeroesSHA256SeparateFile", + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + expectedHash: []byte{115, 151, 35, 147, 223, 91, 17, 6, 162, 145, 237, 81, 88, 53, 120, 49, 128, 70, 188, 28, 254, 241, 19, 233, 30, 243, 71, 225, 57, 58, 61, 38}, + }, + { + name: "MultiplePageZeroesSHA256SameFile", + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedHash: []byte{115, 151, 35, 147, 223, 91, 17, 6, 162, 145, 237, 81, 88, 53, 120, 49, 128, 70, 188, 28, 254, 241, 19, 233, 30, 243, 71, 225, 57, 58, 61, 38}, + }, + { + name: "MultiplePageZeroesSHA512SeparateFile", + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedHash: []byte{41, 94, 205, 97, 254, 226, 171, 69, 76, 102, 197, 47, 113, 53, 24, 244, 103, 131, 83, 73, 87, 212, 247, 140, 32, 144, 211, 158, 25, 131, 194, 57, 21, 224, 128, 119, 69, 100, 45, 50, 157, 54, 46, 214, 152, 179, 59, 78, 28, 48, 146, 160, 204, 48, 27, 90, 152, 193, 167, 45, 150, 67, 66, 217}, + }, + { + name: "MultiplePageZeroesSHA512SameFile", + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedHash: []byte{41, 94, 205, 97, 254, 226, 171, 69, 76, 102, 197, 47, 113, 53, 24, 244, 103, 131, 83, 73, 87, 212, 247, 140, 32, 144, 211, 158, 25, 131, 194, 57, 21, 224, 128, 119, 69, 100, 45, 50, 157, 54, 46, 214, 152, 179, 59, 78, 28, 48, 146, 160, 204, 48, 27, 90, 152, 193, 167, 45, 150, 67, 66, 217}, + }, + { + name: "SingleASHA256SeparateFile", + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + expectedHash: []byte{52, 159, 140, 206, 140, 138, 231, 140, 94, 14, 252, 66, 175, 128, 191, 14, 52, 215, 190, 184, 165, 50, 182, 224, 42, 156, 145, 0, 1, 15, 187, 85}, + }, + { + name: "SingleASHA256SameFile", + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedHash: []byte{52, 159, 140, 206, 140, 138, 231, 140, 94, 14, 252, 66, 175, 128, 191, 14, 52, 215, 190, 184, 165, 50, 182, 224, 42, 156, 145, 0, 1, 15, 187, 85}, + }, + { + name: "SingleASHA512SeparateFile", + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedHash: []byte{232, 90, 223, 95, 60, 151, 149, 172, 174, 58, 206, 97, 189, 103, 6, 202, 67, 248, 1, 189, 243, 51, 250, 42, 5, 89, 195, 9, 50, 74, 39, 169, 114, 228, 109, 225, 128, 210, 63, 94, 18, 133, 58, 48, 225, 100, 176, 55, 87, 60, 235, 224, 143, 41, 15, 253, 94, 28, 251, 233, 99, 207, 152, 108}, + }, + { + name: "SingleASHA512SameFile", + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedHash: []byte{232, 90, 223, 95, 60, 151, 149, 172, 174, 58, 206, 97, 189, 103, 6, 202, 67, 248, 1, 189, 243, 51, 250, 42, 5, 89, 195, 9, 50, 74, 39, 169, 114, 228, 109, 225, 128, 210, 63, 94, 18, 133, 58, 48, 225, 100, 176, 55, 87, 60, 235, 224, 143, 41, 15, 253, 94, 28, 251, 233, 99, 207, 152, 108}, + }, + { + name: "OnePageASHA256SeparateFile", + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + expectedHash: []byte{157, 60, 139, 54, 248, 39, 187, 77, 31, 107, 241, 26, 240, 49, 83, 159, 182, 60, 128, 85, 121, 204, 15, 249, 44, 248, 127, 134, 58, 220, 41, 185}, }, { - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, + name: "OnePageASHA256SameFile", + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedHash: []byte{157, 60, 139, 54, 248, 39, 187, 77, 31, 107, 241, 26, 240, 49, 83, 159, 182, 60, 128, 85, 121, 204, 15, 249, 44, 248, 127, 134, 58, 220, 41, 185}, }, { - data: []byte{'a'}, - expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, + name: "OnePageASHA512SeparateFile", + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedHash: []byte{116, 22, 252, 100, 32, 241, 254, 228, 167, 228, 110, 146, 156, 189, 6, 30, 27, 127, 94, 181, 15, 98, 173, 60, 34, 102, 92, 174, 181, 80, 205, 90, 88, 12, 125, 194, 148, 175, 184, 168, 37, 66, 127, 194, 19, 132, 93, 147, 168, 217, 227, 131, 100, 25, 213, 255, 132, 60, 196, 217, 24, 158, 1, 50}, }, { - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, + name: "OnePageASHA512SameFile", + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedHash: []byte{116, 22, 252, 100, 32, 241, 254, 228, 167, 228, 110, 146, 156, 189, 6, 30, 27, 127, 94, 181, 15, 98, 173, 60, 34, 102, 92, 174, 181, 80, 205, 90, 88, 12, 125, 194, 148, 175, 184, 168, 37, 66, 127, 194, 19, 132, 93, 147, 168, 217, 227, 131, 100, 25, 213, 255, 132, 60, 196, 217, 24, 158, 1, 50}, }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) { - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - params := GenerateParams{ - Size: int64(len(tc.data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(tc.data) - params.File = &tree - } else { - params.File = &bytesReadWriter{ - bytes: tc.data, - } - } - hash, err := Generate(¶ms) - if err != nil { - t.Fatalf("Got err: %v, want nil", err) + t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { + var tree bytesReadWriter + params := GenerateParams{ + Size: int64(len(tc.data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + Children: make(map[string]struct{}), + HashAlgorithms: tc.hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: tc.dataAndTreeInSameFile, + } + if tc.dataAndTreeInSameFile { + tree.Write(tc.data) + params.File = &tree + } else { + params.File = &bytesReadWriter{ + bytes: tc.data, } + } + hash, err := Generate(¶ms) + if err != nil { + t.Fatalf("Got err: %v, want nil", err) + } + if !bytes.Equal(hash, tc.expectedHash) { + t.Errorf("Got hash: %v, want %v", hash, tc.expectedHash) + } + }) + } +} - if !bytes.Equal(hash, tc.expectedHash) { - t.Errorf("Got hash: %v, want %v", hash, tc.expectedHash) - } +// prepareVerify generates test data and corresponding Merkle tree, and returns +// the prepared VerifyParams. +// The test data has size dataSize. The data is hashed with hashAlgorithms. The +// portion to be verified ranges from verifyStart with verifySize. +func prepareVerify(t *testing.T, dataSize int64, hashAlgorithm int, dataAndTreeInSameFile bool, verifyStart int64, verifySize int64, out io.Writer) ([]byte, VerifyParams) { + t.Helper() + data := make([]byte, dataSize) + // Generate random bytes in data. + rand.Read(data) + + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + Children: make(map[string]struct{}), + HashAlgorithms: hashAlgorithm, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("could not generate Merkle tree:%v", err) + } + + return data, VerifyParams{ + Out: out, + File: bytes.NewReader(data), + Tree: &tree, + Size: dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + Children: make(map[string]struct{}), + HashAlgorithms: hashAlgorithm, + ReadOffset: verifyStart, + ReadSize: verifySize, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } +} + +func TestVerifyInvalidRange(t *testing.T) { + testCases := []struct { + name string + verifyStart int64 + verifySize int64 + }{ + // Verify range starts outside data range. + { + name: "StartOutsideRange", + verifyStart: usermem.PageSize, + verifySize: 1, + }, + // Verify range ends outside data range. + { + name: "EndOutsideRange", + verifyStart: 0, + verifySize: 2 * usermem.PageSize, + }, + // Verify range with negative size. + { + name: "NegativeSize", + verifyStart: 1, + verifySize: -1, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, linux.FS_VERITY_HASH_ALG_SHA256, false /* dataAndTreeInSameFile */, tc.verifyStart, tc.verifySize, &buf) + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyUnmodifiedMetadata(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + if _, err := Verify(¶ms); !errors.Is(err, nil) { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + }) + } +} + +func TestVerifyModifiedName(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.Name += "abc" + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedSize(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.Size-- + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedMode(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.Mode++ + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedUID(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.UID++ + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedGID(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.GID++ + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") + } + }) + } +} + +func TestVerifyModifiedChildren(t *testing.T) { + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + _, params := prepareVerify(t, usermem.PageSize /* dataSize */, tc.hashAlgorithm, tc.dataAndTreeInSameFile, 0 /* verifyStart */, 0 /* verifySize */, &buf) + params.Children["abc"] = struct{}{} + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") } }) } } -func TestVerify(t *testing.T) { - // The input data has size dataSize. The portion to be verified ranges from - // verifyStart with verifySize. A bit is flipped in outOfRangeByteIndex to - // confirm that modifications outside the verification range does not cause - // issue. And a bit is flipped in modifyByte to confirm that - // modifications in the verification range is caught during verification. +func TestModifyOutsideVerifyRange(t *testing.T) { testCases := []struct { - dataSize int64 + name string + // The byte with index modifyByte is modified. + modifyByte int64 + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "BeforeRangeSHA256SeparateFile", + modifyByte: 4*usermem.PageSize - 1, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "BeforeRangeSHA512SeparateFile", + modifyByte: 4*usermem.PageSize - 1, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "BeforeRangeSHA256SameFile", + modifyByte: 4*usermem.PageSize - 1, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "BeforeRangeSHA512SameFile", + modifyByte: 4*usermem.PageSize - 1, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + { + name: "AfterRangeSHA256SeparateFile", + modifyByte: 5 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "AfterRangeSHA512SeparateFile", + modifyByte: 5 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "AfterRangeSHA256SameFile", + modifyByte: 5 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "AfterRangeSHA256SameFile", + modifyByte: 5 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dataSize := int64(8 * usermem.PageSize) + verifyStart := int64(4 * usermem.PageSize) + verifySize := int64(usermem.PageSize) + var buf bytes.Buffer + // Modified byte is outside verify range. Verify should succeed. + data, params := prepareVerify(t, dataSize, tc.hashAlgorithm, tc.dataAndTreeInSameFile, verifyStart, verifySize, &buf) + // Flip a bit in data and checks Verify results. + data[tc.modifyByte] ^= 1 + n, err := Verify(¶ms) + if !errors.Is(err, nil) { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + if n != verifySize { + t.Errorf("Got Verify output size %d, want %d", n, verifySize) + } + if int64(buf.Len()) != verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), verifySize) + } + if !bytes.Equal(data[verifyStart:verifyStart+verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } + }) + } +} + +func TestModifyInsideVerifyRange(t *testing.T) { + testCases := []struct { + name string verifyStart int64 verifySize int64 - // A byte in input data is modified during the test. If the - // modified byte falls in verification range, Verify should - // fail, otherwise Verify should still succeed. - modifyByte int64 - modifyName bool - modifyMode bool - modifyUID bool - modifyGID bool - shouldSucceed bool + // The byte with index modifyByte is modified. + modifyByte int64 + hashAlgorithm int + dataAndTreeInSameFile bool }{ - // Verify range start outside the data range should fail. - { - dataSize: usermem.PageSize, - verifyStart: usermem.PageSize, - verifySize: 1, - modifyByte: 0, - shouldSucceed: false, - }, - // Verifying range is valid if it starts inside data and ends - // outside data range, in that case start to the end of data is - // verified. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 2 * usermem.PageSize, - modifyByte: 0, - shouldSucceed: false, - }, - // Invalid verify range (negative size) should fail. - { - dataSize: usermem.PageSize, - verifyStart: 1, - verifySize: -1, - modifyByte: 0, - shouldSucceed: false, - }, - // 0 verify size should only verify metadata. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - shouldSucceed: true, - }, - // Modified name should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifyName: true, - shouldSucceed: false, - }, - // Modified mode should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifyMode: true, - shouldSucceed: false, - }, - // Modified UID should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifyUID: true, - shouldSucceed: false, - }, - // Modified GID should fail verification. - { - dataSize: usermem.PageSize, - verifyStart: 0, - verifySize: 0, - modifyByte: 0, - modifyGID: true, - shouldSucceed: false, - }, - // The test cases below use a block-aligned verify range. + // Test a block-aligned verify range. // Modifying a byte in the verified range should cause verify // to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4 * usermem.PageSize, - verifySize: usermem.PageSize, - modifyByte: 4 * usermem.PageSize, - shouldSucceed: false, + name: "BlockAlignedRangeSHA256SeparateFile", + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, }, - // Modifying a byte before the verified range should not cause - // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4 * usermem.PageSize, - verifySize: usermem.PageSize, - modifyByte: 4*usermem.PageSize - 1, - shouldSucceed: true, + name: "BlockAlignedRangeSHA512SeparateFile", + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "BlockAlignedRangeSHA256SameFile", + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, }, - // Modifying a byte after the verified range should not cause - // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4 * usermem.PageSize, - verifySize: usermem.PageSize, - modifyByte: 5 * usermem.PageSize, - shouldSucceed: true, + name: "BlockAlignedRangeSHA512SameFile", + verifyStart: 4 * usermem.PageSize, + verifySize: usermem.PageSize, + modifyByte: 4 * usermem.PageSize, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // The tests below use a non-block-aligned verify range. // Modifying a byte at strat of verify range should cause // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 4*usermem.PageSize + 123, - shouldSucceed: false, + name: "ModifyStartSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyStartSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyStartSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyStartSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // Modifying a byte at the end of verify range should cause // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 6*usermem.PageSize + 123, - shouldSucceed: false, + name: "ModifyEndSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyEndSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyEndSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyEndSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // Modifying a byte in the middle verified block should cause // verify to fail. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 5*usermem.PageSize + 123, - shouldSucceed: false, + name: "ModifyMiddleSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyMiddleSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyMiddleSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyMiddleSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 5*usermem.PageSize + 123, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // Modifying a byte in the first block in the verified range // should cause verify to fail, even the modified bit itself is // out of verify range. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 4*usermem.PageSize + 122, - shouldSucceed: false, + name: "ModifyFirstBlockSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyFirstBlockSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyFirstBlockSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyFirstBlockSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 4*usermem.PageSize + 122, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, // Modifying a byte in the last block in the verified range // should cause verify to fail, even the modified bit itself is // out of verify range. { - dataSize: 8 * usermem.PageSize, - verifyStart: 4*usermem.PageSize + 123, - verifySize: 2 * usermem.PageSize, - modifyByte: 6*usermem.PageSize + 124, - shouldSucceed: false, + name: "ModifyLastBlockSHA256SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyLastBlockSHA512SeparateFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "ModifyLastBlockSHA256SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "ModifyLastBlockSHA512SameFile", + verifyStart: 4*usermem.PageSize + 123, + verifySize: 2 * usermem.PageSize, + modifyByte: 6*usermem.PageSize + 124, + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, }, } - for _, tc := range testCases { - t.Run(fmt.Sprintf("%d", tc.modifyByte), func(t *testing.T) { - data := make([]byte, tc.dataSize) - // Generate random bytes in data. - rand.Read(data) - - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, - } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - - // Flip a bit in data and checks Verify results. - var buf bytes.Buffer - data[tc.modifyByte] ^= 1 - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: tc.dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: tc.verifyStart, - ReadSize: tc.verifySize, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if tc.modifyName { - verifyParams.Name = defaultName + "abc" - } - if tc.modifyMode { - verifyParams.Mode = defaultMode + 1 - } - if tc.modifyUID { - verifyParams.UID = defaultUID + 1 - } - if tc.modifyGID { - verifyParams.GID = defaultGID + 1 - } - if tc.shouldSucceed { - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed when expected to succeed: %v", err) - } - if n != tc.verifySize { - t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) - } - if int64(buf.Len()) != tc.verifySize { - t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) - } - if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } - } else { - if _, err := Verify(&verifyParams); err == nil { - t.Errorf("Verification succeeded when expected to fail") - } - } + t.Run(tc.name, func(t *testing.T) { + dataSize := int64(8 * usermem.PageSize) + var buf bytes.Buffer + data, params := prepareVerify(t, dataSize, tc.hashAlgorithm, tc.dataAndTreeInSameFile, tc.verifyStart, tc.verifySize, &buf) + // Flip a bit in data and checks Verify results. + data[tc.modifyByte] ^= 1 + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Errorf("Verification succeeded when expected to fail") } }) } } func TestVerifyRandom(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - // Use a random dataSize. Minimum size 2 so that we can pick a random - // portion from it. - dataSize := rand.Int63n(200*usermem.PageSize) + 2 - data := make([]byte, dataSize) - // Generate random bytes in data. - rand.Read(data) - - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, - } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } + testCases := []struct { + name string + hashAlgorithm int + dataAndTreeInSameFile bool + }{ + { + name: "SHA256SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: false, + }, + { + name: "SHA512SeparateFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + }, + { + name: "SHA256SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + }, + { + name: "SHA512SameFile", + hashAlgorithm: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + // Use a random dataSize. Minimum size 2 so that we can pick a random + // portion from it. + dataSize := rand.Int63n(200*usermem.PageSize) + 2 - // Pick a random portion of data. - start := rand.Int63n(dataSize - 1) - size := rand.Int63n(dataSize) + 1 + // Pick a random portion of data. + start := rand.Int63n(dataSize - 1) + size := rand.Int63n(dataSize) + 1 - var buf bytes.Buffer - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: start, - ReadSize: size, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + var buf bytes.Buffer + data, params := prepareVerify(t, dataSize, tc.hashAlgorithm, tc.dataAndTreeInSameFile, start, size, &buf) - // Checks that the random portion of data from the original data is - // verified successfully. - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed for correct data: %v", err) - } - if size > dataSize-start { - size = dataSize - start - } - if n != size { - t.Errorf("Got Verify output size %d, want %d", n, size) - } - if int64(buf.Len()) != size { - t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) - } - if !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } + // Checks that the random portion of data from the original data is + // verified successfully. + n, err := Verify(¶ms) + if err != nil && err != io.EOF { + t.Errorf("Verification failed for correct data: %v", err) + } + if size > dataSize-start { + size = dataSize - start + } + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } - // Verify that modified metadata should fail verification. - buf.Reset() - verifyParams.Name = defaultName + "abc" - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verify succeeded for modified metadata, expect failure") - } + // Verify that modified metadata should fail verification. + buf.Reset() + params.Name = defaultName + "abc" + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Error("Verify succeeded for modified metadata, expect failure") + } - // Flip a random bit in randPortion, and check that verification fails. - buf.Reset() - randBytePos := rand.Int63n(size) - data[start+randBytePos] ^= 1 - verifyParams.File = bytes.NewReader(data) - verifyParams.Name = defaultName + // Flip a random bit in randPortion, and check that verification fails. + buf.Reset() + randBytePos := rand.Int63n(size) + data[start+randBytePos] ^= 1 + params.File = bytes.NewReader(data) + params.Name = defaultName - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verification succeeded for modified data, expect failure") - } + if _, err := Verify(¶ms); errors.Is(err, nil) { + t.Error("Verification succeeded for modified data, expect failure") + } + }) } } diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 71e944c30..eadea390a 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -570,6 +570,8 @@ func (c *Client) Version() uint32 { func (c *Client) Close() { // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already // been called (by c.watch()). - c.socket.Shutdown() + if err := c.socket.Shutdown(); err != nil { + log.Warningf("Socket.Shutdown() failed (FD: %d): %v", c.socket.FD(), err) + } c.closedWg.Wait() } diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 28fe081d6..8b46a2987 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -478,28 +478,23 @@ func (r *ReadWriterFile) ReadAt(p []byte, offset int64) (int, error) { } // Write implements part of the io.ReadWriter interface. +// +// Note that this may return a short write with a nil error. This violates the +// contract of io.Writer, but is more consistent with gVisor's pattern of +// returning errors that correspond to Linux errnos. Since short writes without +// error are common in Linux, returning a nil error is appropriate. func (r *ReadWriterFile) Write(p []byte) (int, error) { n, err := r.File.WriteAt(p, r.Offset) r.Offset += uint64(n) - if err != nil { - return n, err - } - if n < len(p) { - return n, io.ErrShortWrite - } - return n, nil + return n, err } // WriteAt implements the io.WriteAt interface. +// +// Note that this may return a short write with a nil error. This violates the +// contract of io.WriterAt. See comment on Write for justification. func (r *ReadWriterFile) WriteAt(p []byte, offset int64) (int, error) { - n, err := r.File.WriteAt(p, uint64(offset)) - if err != nil { - return n, err - } - if n < len(p) { - return n, io.ErrShortWrite - } - return n, nil + return r.File.WriteAt(p, uint64(offset)) } // Rename implements File.Rename. diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index abd237f46..81ceb37c5 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -296,25 +296,6 @@ func (t *Tlopen) handle(cs *connState) message { } defer ref.DecRef() - ref.openedMu.Lock() - defer ref.openedMu.Unlock() - - // Has it been opened already? - if ref.opened || !CanOpen(ref.mode) { - return newErr(syscall.EINVAL) - } - - if ref.mode.IsDir() { - // Directory must be opened ReadOnly. - if t.Flags&OpenFlagsModeMask != ReadOnly { - return newErr(syscall.EISDIR) - } - // Directory not truncatable. - if t.Flags&OpenTruncate != 0 { - return newErr(syscall.EISDIR) - } - } - var ( qid QID ioUnit uint32 @@ -326,6 +307,22 @@ func (t *Tlopen) handle(cs *connState) message { return syscall.EINVAL } + // Has it been opened already? + if ref.opened || !CanOpen(ref.mode) { + return syscall.EINVAL + } + + if ref.mode.IsDir() { + // Directory must be opened ReadOnly. + if t.Flags&OpenFlagsModeMask != ReadOnly { + return syscall.EISDIR + } + // Directory not truncatable. + if t.Flags&OpenTruncate != 0 { + return syscall.EISDIR + } + } + osFile, qid, ioUnit, err = ref.file.Open(t.Flags) return err }); err != nil { @@ -366,7 +363,7 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) { } // Not allowed on open directories. - if _, opened := ref.OpenFlags(); opened { + if ref.opened { return syscall.EINVAL } @@ -437,7 +434,7 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) { } // Not allowed on open directories. - if _, opened := ref.OpenFlags(); opened { + if ref.opened { return syscall.EINVAL } @@ -476,7 +473,7 @@ func (t *Tlink) handle(cs *connState) message { } // Not allowed on open directories. - if _, opened := ref.OpenFlags(); opened { + if ref.opened { return syscall.EINVAL } @@ -518,7 +515,7 @@ func (t *Trenameat) handle(cs *connState) message { } // Not allowed on open directories. - if _, opened := ref.OpenFlags(); opened { + if ref.opened { return syscall.EINVAL } @@ -561,7 +558,7 @@ func (t *Tunlinkat) handle(cs *connState) message { } // Not allowed on open directories. - if _, opened := ref.OpenFlags(); opened { + if ref.opened { return syscall.EINVAL } @@ -701,13 +698,12 @@ func (t *Tread) handle(cs *connState) message { ) if err := ref.safelyRead(func() (err error) { // Has it been opened already? - openFlags, opened := ref.OpenFlags() - if !opened { + if !ref.opened { return syscall.EINVAL } // Can it be read? Check permissions. - if openFlags&OpenFlagsModeMask == WriteOnly { + if ref.openFlags&OpenFlagsModeMask == WriteOnly { return syscall.EPERM } @@ -731,13 +727,12 @@ func (t *Twrite) handle(cs *connState) message { var n int if err := ref.safelyRead(func() (err error) { // Has it been opened already? - openFlags, opened := ref.OpenFlags() - if !opened { + if !ref.opened { return syscall.EINVAL } // Can it be written? Check permissions. - if openFlags&OpenFlagsModeMask == ReadOnly { + if ref.openFlags&OpenFlagsModeMask == ReadOnly { return syscall.EPERM } @@ -778,7 +773,7 @@ func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) { } // Not allowed on open directories. - if _, opened := ref.OpenFlags(); opened { + if ref.opened { return syscall.EINVAL } @@ -820,7 +815,7 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) { } // Not allowed on open directories. - if _, opened := ref.OpenFlags(); opened { + if ref.opened { return syscall.EINVAL } @@ -898,13 +893,12 @@ func (t *Tallocate) handle(cs *connState) message { if err := ref.safelyWrite(func() error { // Has it been opened already? - openFlags, opened := ref.OpenFlags() - if !opened { + if !ref.opened { return syscall.EINVAL } // Can it be written? Check permissions. - if openFlags&OpenFlagsModeMask == ReadOnly { + if ref.openFlags&OpenFlagsModeMask == ReadOnly { return syscall.EBADF } @@ -1049,8 +1043,8 @@ func (t *Treaddir) handle(cs *connState) message { return syscall.EINVAL } - // Has it been opened already? - if _, opened := ref.OpenFlags(); !opened { + // Has it been opened yet? + if !ref.opened { return syscall.EINVAL } @@ -1076,8 +1070,8 @@ func (t *Tfsync) handle(cs *connState) message { defer ref.DecRef() if err := ref.safelyRead(func() (err error) { - // Has it been opened already? - if _, opened := ref.OpenFlags(); !opened { + // Has it been opened yet? + if !ref.opened { return syscall.EINVAL } @@ -1185,8 +1179,13 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI } // Has it been opened already? - if _, opened := ref.OpenFlags(); opened { - err = syscall.EBUSY + err = ref.safelyRead(func() (err error) { + if ref.opened { + return syscall.EBUSY + } + return nil + }) + if err != nil { return } diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e605b14c..2e3d427ae 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -678,16 +678,15 @@ func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string // case. defer checkDeleted(h, dst) } else { + // If the type is different than the destination, then + // we expect the rename to fail. We expect that this + // is returned. + // + // If the file being renamed to itself, this is + // technically allowed and a no-op, but all the + // triggers will fire. if !selfRename { - // If the type is different than the - // destination, then we expect the rename to - // fail. We expect ensure that this is - // returned. expectedErr = syscall.EINVAL - } else { - // This is the file being renamed to itself. - // This is technically allowed and a no-op, but - // all the triggers will fire. } dst.Close() } diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 3736f12a3..8c5c434fd 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -134,12 +134,11 @@ type fidRef struct { // The node above will be closed only when refs reaches zero. refs int64 - // openedMu protects opened and openFlags. - openedMu sync.Mutex - // opened indicates whether this has been opened already. // // This is updated in handlers.go. + // + // opened is protected by pathNode.opMu or renameMu (for write). opened bool // mode is the fidRef's mode from the walk. Only the type bits are @@ -151,6 +150,8 @@ type fidRef struct { // openFlags is the mode used in the open. // // This is updated in handlers.go. + // + // openFlags is protected by pathNode.opMu or renameMu (for write). openFlags OpenFlags // pathNode is the current pathNode for this FID. @@ -177,13 +178,6 @@ type fidRef struct { deleted uint32 } -// OpenFlags returns the flags the file was opened with and true iff the fid was opened previously. -func (f *fidRef) OpenFlags() (OpenFlags, bool) { - f.openedMu.Lock() - defer f.openedMu.Unlock() - return f.openFlags, f.opened -} - // IncRef increases the references on a fid. func (f *fidRef) IncRef() { atomic.AddInt64(&f.refs, 1) diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index e7406b374..a29f06ddb 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -197,33 +197,33 @@ func BenchmarkSendRecv(b *testing.B) { for i := 0; i < b.N; i++ { tag, m, err := recv(server, maximumLength, msgRegistry.get) if err != nil { - b.Fatalf("recv got err %v expected nil", err) + b.Errorf("recv got err %v expected nil", err) } if tag != Tag(1) { - b.Fatalf("got tag %v expected 1", tag) + b.Errorf("got tag %v expected 1", tag) } if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %T expected *Rflush", m) + b.Errorf("got message %T expected *Rflush", m) } if err := send(server, Tag(2), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) + b.Errorf("send got err %v expected nil", err) } } }() b.ResetTimer() for i := 0; i < b.N; i++ { if err := send(client, Tag(1), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) + b.Errorf("send got err %v expected nil", err) } tag, m, err := recv(client, maximumLength, msgRegistry.get) if err != nil { - b.Fatalf("recv got err %v expected nil", err) + b.Errorf("recv got err %v expected nil", err) } if tag != Tag(2) { - b.Fatalf("got tag %v expected 2", tag) + b.Errorf("got tag %v expected 2", tag) } if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %v expected *Rflush", m) + b.Errorf("got message %v expected *Rflush", m) } } } diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index a1b2e0cfe..54e825b28 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package pool provides a trivial integer pool. package pool import ( diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 699ea8ac3..6992e1de8 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -319,7 +319,8 @@ func makeStackKey(pcs []uintptr) stackKey { return key } -func recordStack() []uintptr { +// RecordStack constructs and returns the PCs on the current stack. +func RecordStack() []uintptr { pcs := make([]uintptr, maxStackFrames) n := runtime.Callers(1, pcs) if n == 0 { @@ -342,7 +343,8 @@ func recordStack() []uintptr { return v } -func formatStack(pcs []uintptr) string { +// FormatStack converts the given stack into a readable format. +func FormatStack(pcs []uintptr) string { frames := runtime.CallersFrames(pcs) var trace bytes.Buffer for { @@ -367,7 +369,7 @@ func (r *AtomicRefCount) finalize() { if n := r.ReadRefs(); n != 0 { msg := fmt.Sprintf("%sAtomicRefCount %p owned by %q garbage collected with ref count of %d (want 0)", note, r, r.name, n) if len(r.stack) != 0 { - msg += ":\nCaller:\n" + formatStack(r.stack) + msg += ":\nCaller:\n" + FormatStack(r.stack) } else { msg += " (enable trace logging to debug)" } @@ -392,7 +394,7 @@ func (r *AtomicRefCount) EnableLeakCheck(name string) { case NoLeakChecking: return case LeaksLogTraces: - r.stack = recordStack() + r.stack = RecordStack() } r.name = name runtime.SetFinalizer(r, (*AtomicRefCount).finalize) diff --git a/pkg/refs_vfs2/BUILD b/pkg/refsvfs2/BUILD index 577b827a5..0377c0876 100644 --- a/pkg/refs_vfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -8,6 +8,9 @@ go_template( srcs = [ "refs_template.go", ], + opt_consts = [ + "enableLogging", + ], types = [ "T", ], @@ -19,8 +22,16 @@ go_template( ) go_library( - name = "refs_vfs2", - srcs = ["refs.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/context"], + name = "refsvfs2", + srcs = [ + "refs.go", + "refs_map.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/context", + "//pkg/log", + "//pkg/refs", + "//pkg/sync", + ], ) diff --git a/pkg/refsvfs2/README.md b/pkg/refsvfs2/README.md new file mode 100644 index 000000000..eca53c282 --- /dev/null +++ b/pkg/refsvfs2/README.md @@ -0,0 +1,66 @@ +# Reference Counting + +Go does not offer a reliable way to couple custom resource management with +object lifetime. As a result, we need to manually implement reference counting +for many objects in gVisor to make sure that resources are acquired and released +appropriately. For example, the filesystem has many reference-counted objects +(file descriptions, dentries, inodes, etc.), and it is important that each +object persists while anything holds a reference on it and is destroyed once all +references are dropped. + +We provide a template in `refs_template.go` that can be applied to most objects +in need of reference counting. It contains a simple `Refs` struct that can be +incremented and decremented, and once the reference count reaches zero, a +destructor can be called. Note that there are some objects (e.g. `gofer.dentry`, +`overlay.dentry`) that should not immediately be destroyed upon reaching zero +references; in these cases, this template cannot be applied. + +# Reference Checking + +Unfortunately, manually keeping track of reference counts is extremely error +prone, and improper accounting can lead to production bugs that are very +difficult to root cause. + +We have several ways of discovering reference count errors in gVisor. Any +attempt to increment/decrement a `Refs` struct with a count of zero will trigger +a sentry panic, since the object should have been destroyed and become +unreachable. This allows us to identify missing increments or extra decrements, +which cause the reference count to be lower than it should be: the count will +reach zero earlier than expected, and the next increment/decrement--which should +be valid--will result in a panic. + +It is trickier to identify extra increments and missing decrements, which cause +the reference count to be higher than expected (i.e. a “reference leak”). +Reference leaks prevent resources from being released properly and can translate +to various issues that are tricky to diagnose, such as memory leaks. The +following section discusses how we implement leak checking. + +## Leak Checking + +When leak checking is enabled, reference-counted objects are added to a global +map when constructed and removed when destroyed. Near the very end of sandbox +execution, once no reference-counted objects should still be reachable, we +report everything left in the map as having leaked. Leak-checking objects +implement the `CheckedObject` interface, which allows us to print informative +warnings for each of the leaked objects. + +Leak checking is provided by `refs_template`, but objects that do not use the +template will also need to implement `CheckedObject` and be manually +registered/unregistered from the map in order to be checked. + +Note that leak checking affects performance and memory usage, so it should only +be enabled in testing environments. + +## Debugging + +Even with the checks described above, it can be difficult to track down the +exact source of a reference counting error. The error may occur far before it is +discovered (for instance, a missing `IncRef` may not be discovered until a +future `DecRef` makes the count negative). To aid in debugging, `refs_template` +provides the `enableLogging` option to log every `IncRef`, `DecRef`, and leak +check registration/unregistration, along with the object address and a call +stack. This allows us to search a log for all of the changes to a particular +object's reference count, which makes it much easier to identify the absent or +extraneous operation(s). The reference-counted objects that do not use +`refs_template` also provide logging, and others defined in the future should do +so as well. diff --git a/pkg/refs_vfs2/refs.go b/pkg/refsvfs2/refs.go index 99a074e96..ef8beb659 100644 --- a/pkg/refs_vfs2/refs.go +++ b/pkg/refsvfs2/refs.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package refs_vfs2 defines an interface for a reference-counted object. -package refs_vfs2 +// Package refsvfs2 defines an interface for a reference-counted object. +package refsvfs2 import ( "gvisor.dev/gvisor/pkg/context" diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go new file mode 100644 index 000000000..9fbc5466f --- /dev/null +++ b/pkg/refsvfs2/refs_map.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package refsvfs2 + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/log" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" +) + +var ( + // liveObjects is a global map of reference-counted objects. Objects are + // inserted when leak check is enabled, and they are removed when they are + // destroyed. It is protected by liveObjectsMu. + liveObjects map[CheckedObject]struct{} + liveObjectsMu sync.Mutex +) + +// CheckedObject represents a reference-counted object with an informative +// leak detection message. +type CheckedObject interface { + // RefType is the type of the reference-counted object. + RefType() string + + // LeakMessage supplies a warning to be printed upon leak detection. + LeakMessage() string + + // LogRefs indicates whether reference-related events should be logged. + LogRefs() bool +} + +func init() { + liveObjects = make(map[CheckedObject]struct{}) +} + +// leakCheckEnabled returns whether leak checking is enabled. The following +// functions should only be called if it returns true. +func leakCheckEnabled() bool { + return refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking +} + +// Register adds obj to the live object map. +func Register(obj CheckedObject) { + if leakCheckEnabled() { + liveObjectsMu.Lock() + if _, ok := liveObjects[obj]; ok { + panic(fmt.Sprintf("Unexpected entry in leak checking map: reference %p already added", obj)) + } + liveObjects[obj] = struct{}{} + liveObjectsMu.Unlock() + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, "registered") + } + } +} + +// Unregister removes obj from the live object map. +func Unregister(obj CheckedObject) { + if leakCheckEnabled() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + if _, ok := liveObjects[obj]; !ok { + panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj)) + } + delete(liveObjects, obj) + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, "unregistered") + } + } +} + +// LogIncRef logs a reference increment. +func LogIncRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("IncRef to %d", refs)) + } +} + +// LogTryIncRef logs a successful TryIncRef call. +func LogTryIncRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs)) + } +} + +// LogDecRef logs a reference decrement. +func LogDecRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("DecRef to %d", refs)) + } +} + +// logEvent logs a message for the given reference-counted object. +// +// obj.LogRefs() should be checked before calling logEvent, in order to avoid +// calling any text processing needed to evaluate msg. +func logEvent(obj CheckedObject, msg string) { + log.Infof("[%s %p] %s:", obj.RefType(), obj, msg) + log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack())) +} + +// DoLeakCheck iterates through the live object map and logs a message for each +// object. It is called once no reference-counted objects should be reachable +// anymore, at which point anything left in the map is considered a leak. +func DoLeakCheck() { + if leakCheckEnabled() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + leaked := len(liveObjects) + if leaked > 0 { + log.Warningf("Leak checking detected %d leaked objects:", leaked) + for obj := range liveObjects { + log.Warningf(obj.LeakMessage()) + } + } + } +} diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index d9b552896..3fbc91aa5 100644 --- a/pkg/refs_vfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -13,40 +13,33 @@ // limitations under the License. // Package refs_template defines a template that can be used by reference -// counted objects. The "owner" template parameter is used in log messages to -// indicate the type of reference-counted object that exhibited a reference -// leak. As a result, structs that are embedded in other structs should not use -// this template, since it will make tracking down leaks more difficult. +// counted objects. package refs_template import ( "fmt" - "runtime" "sync/atomic" - "gvisor.dev/gvisor/pkg/log" - refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" ) +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const enableLogging = false + // T is the type of the reference counted object. It is only used to customize // debug output when leak checking. type T interface{} -// ownerType is used to customize logging. Note that we use a pointer to T so -// that we do not copy the entire object when passed as a format parameter. -var ownerType *T +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var obj *T // Refs implements refs.RefCounter. It keeps a reference count using atomic // operations and calls the destructor when the count reaches zero. // -// Note that the number of references is actually refCount + 1 so that a default -// zero-value Refs object contains one reference. -// -// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in -// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount. -// This will allow us to add stack trace information to the leak messages -// without growing the size of Refs. -// // +stateify savable type Refs struct { // refCount is composed of two fields: @@ -59,39 +52,44 @@ type Refs struct { refCount int64 } -func (r *Refs) finalize() { - var note string - switch refs_vfs1.GetLeakMode() { - case refs_vfs1.NoLeakChecking: - return - case refs_vfs1.UninitializedLeakChecking: - note = "(Leak checker uninitialized): " - } - if n := r.ReadRefs(); n != 0 { - log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) - } +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *Refs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) } -// EnableLeakCheck checks for reference leaks when Refs gets garbage collected. -func (r *Refs) EnableLeakCheck() { - if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { - runtime.SetFinalizer(r, (*Refs).finalize) - } +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *Refs) RefType() string { + return fmt.Sprintf("%T", obj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *Refs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *Refs) LogRefs() bool { + return enableLogging } // ReadRefs returns the current number of references. The returned count is // inherently racy and is unsafe to use without external synchronization. func (r *Refs) ReadRefs() int64 { - // Account for the internal -1 offset on refcounts. - return atomic.LoadInt64(&r.refCount) + 1 + return atomic.LoadInt64(&r.refCount) } // IncRef implements refs.RefCounter.IncRef. // //go:nosplit func (r *Refs) IncRef() { - if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { - panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType)) + v := atomic.AddInt64(&r.refCount, 1) + if enableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) } } @@ -104,15 +102,17 @@ func (r *Refs) IncRef() { //go:nosplit func (r *Refs) TryIncRef() bool { const speculativeRef = 1 << 32 - v := atomic.AddInt64(&r.refCount, speculativeRef) - if int32(v) < 0 { + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { // This object has already been freed. atomic.AddInt64(&r.refCount, -speculativeRef) return false } // Turn into a real reference. - atomic.AddInt64(&r.refCount, -speculativeRef+1) + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + if enableLogging { + refsvfs2.LogTryIncRef(r, v) + } return true } @@ -129,14 +129,25 @@ func (r *Refs) TryIncRef() bool { // //go:nosplit func (r *Refs) DecRef(destroy func()) { - switch v := atomic.AddInt64(&r.refCount, -1); { - case v < -1: - panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType)) + v := atomic.AddInt64(&r.refCount, -1) + if enableLogging { + refsvfs2.LogDecRef(r, v) + } + switch { + case v < 0: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) - case v == -1: + case v == 0: + refsvfs2.Unregister(r) // Call the destructor. if destroy != nil { destroy() } } } + +func (r *Refs) afterLoad() { + if r.ReadRefs() > 0 { + refsvfs2.Register(r) + } +} diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go index e7fd30743..7857f5853 100644 --- a/pkg/safemem/block_unsafe.go +++ b/pkg/safemem/block_unsafe.go @@ -68,29 +68,29 @@ func blockFromSlice(slice []byte, needSafecopy bool) Block { } } -// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is +// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+length), which is // safe to access without safecopy. // -// Preconditions: ptr+len does not overflow. -func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block { - return blockFromPointer(ptr, len, false) +// Preconditions: ptr+length does not overflow. +func BlockFromSafePointer(ptr unsafe.Pointer, length int) Block { + return blockFromPointer(ptr, length, false) } // BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which // is not safe to access without safecopy. // // Preconditions: ptr+len does not overflow. -func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block { - return blockFromPointer(ptr, len, true) +func BlockFromUnsafePointer(ptr unsafe.Pointer, length int) Block { + return blockFromPointer(ptr, length, true) } -func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block { - if uptr := uintptr(ptr); uptr+uintptr(len) < uptr { - panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len)) +func blockFromPointer(ptr unsafe.Pointer, length int, needSafecopy bool) Block { + if uptr := uintptr(ptr); uptr+uintptr(length) < uptr { + panic(fmt.Sprintf("ptr %#x + len %#x overflows", uptr, length)) } return Block{ start: ptr, - length: len, + length: length, needSafecopy: needSafecopy, } } diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index 752e2dc32..ec17ebc4d 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -79,7 +79,7 @@ func Install(rules SyscallRules) error { // Perform the actual installation. if errno := SetFilter(instrs); errno != 0 { - return fmt.Errorf("Failed to set filter: %v", errno) + return fmt.Errorf("failed to set filter: %v", errno) } log.Infof("Seccomp filters installed.") diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go index 7cd895cc7..652c010da 100644 --- a/pkg/segment/test/set_functions.go +++ b/pkg/segment/test/set_functions.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package segment is a test package. package segment type setFunctions struct{} diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index c9fb55d00..35d2e07c3 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() { } } -// Pid returns the si_pid field. -func (s *SignalInfo) Pid() int32 { +// PID returns the si_pid field. +func (s *SignalInfo) PID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[0:4])) } -// SetPid mutates the si_pid field. -func (s *SignalInfo) SetPid(val int32) { +// SetPID mutates the si_pid field. +func (s *SignalInfo) SetPID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) } -// Uid returns the si_uid field. -func (s *SignalInfo) Uid() int32 { +// UID returns the si_uid field. +func (s *SignalInfo) UID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[4:8])) } -// SetUid mutates the si_uid field. -func (s *SignalInfo) SetUid(val int32) { +// SetUID mutates the si_uid field. +func (s *SignalInfo) SetUID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) } @@ -251,3 +251,26 @@ func (s *SignalInfo) Arch() uint32 { func (s *SignalInfo) SetArch(val uint32) { usermem.ByteOrder.PutUint32(s.Fields[12:16], val) } + +// Band returns the si_band field. +func (s *SignalInfo) Band() int64 { + return int64(usermem.ByteOrder.Uint64(s.Fields[0:8])) +} + +// SetBand mutates the si_band field. +func (s *SignalInfo) SetBand(val int64) { + // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. + // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is + // `int`. See siginfo.h. + usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) +} + +// FD returns the si_fd field. +func (s *SignalInfo) FD() uint32 { + return usermem.ByteOrder.Uint32(s.Fields[8:12]) +} + +// SetFD mutates the si_fd field. +func (s *SignalInfo) SetFD(val uint32) { + usermem.ByteOrder.PutUint32(s.Fields[8:12], val) +} diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 2bf3c45e1..91b8fb44f 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -193,7 +193,7 @@ func (p *Profile) StopTrace(_, _ *struct{}) error { defer p.mu.Unlock() if p.traceFile == nil { - return errors.New("Execution tracing not started") + return errors.New("execution tracing not started") } // Similarly to the case above, if tasks have not ended traces, we will diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index 41feeffe3..62eaca965 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { Callback: func(err error) { if err == nil { log.Infof("Save succeeded: exiting...") + s.Kernel.SetSaveSuccess(false /* autosave */) } else { log.Warningf("Save failed: exiting...") s.Kernel.SetSaveError(err) @@ -69,5 +70,5 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { s.Kernel.Kill(kernel.ExitStatus{}) }, } - return saveOpts.Save(s.Kernel, s.Watchdog) + return saveOpts.Save(s.Kernel.SupervisorContext(), s.Kernel, s.Watchdog) } diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD index 14a8bf9cd..71c59287c 100644 --- a/pkg/sentry/devices/tundev/BUILD +++ b/pkg/sentry/devices/tundev/BUILD @@ -17,7 +17,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index 655ea549b..d8f4e1d35 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -16,8 +16,6 @@ package tundev import ( - "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -28,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -39,6 +36,8 @@ const ( ) // tunDevice implements vfs.Device for /dev/net/tun. +// +// +stateify savable type tunDevice struct{} // Open implements vfs.Device.Open. @@ -53,6 +52,8 @@ func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opt } // tunFD implements vfs.FileDescriptionImpl for /dev/net/tun. +// +// +stateify savable type tunFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -87,16 +88,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) - created, err := fd.device.SetIff(stack.Stack, req.Name(), flags) - if err == nil && created { - // Always start with an ARP address for interfaces so they can handle ARP - // packets. - nicID := fd.device.NICID() - if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID)) - } - } - return 0, err + return 0, fd.device.SetIff(stack.Stack, req.Name(), flags) case linux.TUNGETIFF: var req linux.IFReq diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index 314661475..badd5b073 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package fdimport provides the Import function. package fdimport import ( diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index ea85ab33c..5c3e852e9 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -49,13 +49,13 @@ go_library( "//pkg/amutex", "//pkg/context", "//pkg/log", - "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/secio", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs/lock", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index ff2fe6712..8e0aa9019 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err // copyUpBuffers is a buffer pool for copying file content. The buffer // size is the same used by io.Copy. -var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }} +var copyUpBuffers = sync.Pool{ + New: func() interface{} { + b := make([]byte, 8*usermem.PageSize) + return &b + }, +} // copyContentsLocked copies the contents of lower to upper. It panics if // less than size bytes can be copied. @@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in defer lowerFile.DecRef(ctx) // Use a buffer pool to minimize allocations. - buf := copyUpBuffers.Get().([]byte) + buf := copyUpBuffers.Get().(*[]byte) defer copyUpBuffers.Put(buf) // Transfer the contents. @@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in // optimizations could be self-defeating. So we leave this as simple as possible. var offset int64 for { - nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset) + nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset) if err != nil && err != io.EOF { return err } @@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in } return nil } - nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset) + nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset) if err != nil { return err } diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index c7a11eec1..e04784db2 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) { wg.Add(1) go func(o *overlayTestFile) { if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil { - t.Fatalf("failed to copy up: %v", err) + t.Errorf("failed to copy up: %v", err) } wg.Done() }(file) diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 6b7b451b8..9379a4d7b 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -34,7 +34,6 @@ go_library( "//pkg/sentry/socket/netstack", "//pkg/syserror", "//pkg/tcpip/link/tun", - "//pkg/tcpip/network/arp", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index 19ffdec47..5227ef652 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -15,8 +15,6 @@ package dev import ( - "fmt" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -27,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/link/tun" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -112,16 +109,7 @@ func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io user return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) - created, err := n.device.SetIff(stack.Stack, req.Name(), flags) - if err == nil && created { - // Always start with an ARP address for interfaces so they can handle ARP - // packets. - nicID := n.device.NICID() - if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID)) - } - } - return 0, err + return 0, n.device.SetIff(stack.Stack, req.Name(), flags) case linux.TUNGETIFF: var req linux.IFReq diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 72ea70fcf..57f904801 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -17,13 +17,12 @@ package fs import ( "math" "sync/atomic" - "time" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" @@ -33,28 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -var ( - // RecordWaitTime controls writing metrics for filesystem reads. - // Enabling this comes at a small CPU cost due to performing two - // monotonic clock reads per read call. - // - // Note that this is only performed in the direct read path, and may - // not be consistently applied for other forms of reads, such as - // splice. - RecordWaitTime = false - - reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.") - readWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.") -) - -// IncrementWait increments the given wait time metric, if enabled. -func IncrementWait(m *metric.Uint64Metric, start time.Time) { - if !RecordWaitTime { - return - } - m.IncrementBy(uint64(time.Since(start))) -} - // FileMaxOffset is the maximum possible file offset. const FileMaxOffset = math.MaxInt64 @@ -257,22 +234,19 @@ func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error { // // Returns syserror.ErrInterrupted if reading was interrupted. func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) { - var start time.Time - if RecordWaitTime { - start = time.Now() - } + start := fsmetric.StartReadWait() + defer fsmetric.FinishReadWait(fsmetric.ReadWait, start) + if !f.mu.Lock(ctx) { - IncrementWait(readWait, start) return 0, syserror.ErrInterrupted } - reads.Increment() + fsmetric.Reads.Increment() n, err := f.FileOperations.Read(ctx, f, dst, f.offset) if n > 0 && !f.flags.NonSeekable { atomic.AddInt64(&f.offset, n) } f.mu.Unlock() - IncrementWait(readWait, start) return n, err } @@ -282,19 +256,16 @@ func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) // // Otherwise same as Readv. func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { - var start time.Time - if RecordWaitTime { - start = time.Now() - } + start := fsmetric.StartReadWait() + defer fsmetric.FinishReadWait(fsmetric.ReadWait, start) + if !f.mu.Lock(ctx) { - IncrementWait(readWait, start) return 0, syserror.ErrInterrupted } - reads.Increment() + fsmetric.Reads.Increment() n, err := f.FileOperations.Read(ctx, f, dst, offset) f.mu.Unlock() - IncrementWait(readWait, start) return n, err } diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go index 8049538f2..ec3d3f96c 100644 --- a/pkg/sentry/fs/filetest/filetest.go +++ b/pkg/sentry/fs/filetest/filetest.go @@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File { // Read just fails the request. func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Readv not implemented") + return 0, fmt.Errorf("TestFileOperations.Read not implemented") } // Write just fails the request. func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Writev not implemented") + return 0, fmt.Errorf("TestFileOperations.Write not implemented") } diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index d2dbff268..a020da53b 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -65,7 +65,7 @@ var ( // runs with the lock held for reading. AsyncBarrier will take the lock // for writing, thus ensuring that all Async work completes before // AsyncBarrier returns. - workMu sync.RWMutex + workMu sync.CrossGoroutineRWMutex // asyncError is used to store up to one asynchronous execution error. asyncError = make(chan error, 1) diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 1390a9a7f..4468f5dd2 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -70,6 +70,13 @@ func (f *HostFileMapper) Init() { f.mappings = make(map[uint64]mapping) } +// IsInited returns true if f.Init() has been called. This is used when +// restoring a checkpoint that contains a HostFileMapper that may or may not +// have been initialized. +func (f *HostFileMapper) IsInited() bool { + return f.refs != nil +} + // NewHostFileMapper returns an initialized HostFileMapper allocated on the // heap with no references or cached mappings. func NewHostFileMapper() *HostFileMapper { diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index fea135eea..4c30098cd 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -28,7 +28,6 @@ go_library( "//pkg/context", "//pkg/fd", "//pkg/log", - "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/safemem", @@ -38,6 +37,7 @@ go_library( "//pkg/sentry/fs/fdpipe", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/host", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go index d481baf77..e5579095b 100644 --- a/pkg/sentry/fs/gofer/attr.go +++ b/pkg/sentry/fs/gofer/attr.go @@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType { return fs.BlockDevice case pattr.Mode.IsSocket(): return fs.Socket - case pattr.Mode.IsRegular(): - fallthrough default: return fs.RegularFile } diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index c0bc63a32..bb63448cb 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -21,27 +21,17 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) -var ( - opensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a writable+executable file was opened from a gofer.") - opens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a 9P file was opened from a gofer.") - opensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a host file was opened from a gofer.") - reads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") - readWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.") - readsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.") - readWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.") -) - // fileOperations implements fs.FileOperations for a remote file system. // // +stateify savable @@ -101,14 +91,14 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF } if flags.Write { if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil { - opensWX.Increment() + fsmetric.GoferOpensWX.Increment() log.Warningf("Opened a writable executable: %q", name) } } if handles.Host != nil { - opensHost.Increment() + fsmetric.GoferOpensHost.Increment() } else { - opens9P.Increment() + fsmetric.GoferOpens9P.Increment() } return fs.NewFile(ctx, dirent, flags, f) } @@ -278,20 +268,17 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I // use this function rather than using a defer in Read() to avoid the performance hit of defer. func (f *fileOperations) incrementReadCounters(start time.Time) { if f.handles.Host != nil { - readsHost.Increment() - fs.IncrementWait(readWaitHost, start) + fsmetric.GoferReadsHost.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start) } else { - reads9P.Increment() - fs.IncrementWait(readWait9P, start) + fsmetric.GoferReads9P.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start) } } // Read implements fs.FileOperations.Read. func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - var start time.Time - if fs.RecordWaitTime { - start = time.Now() - } + start := fsmetric.StartReadWait() if fs.IsDir(file.Dirent.Inode.StableAttr) { // Not all remote file systems enforce this so this client does. f.incrementReadCounters(start) diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 3a225fd39..9d6fdd08f 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -117,7 +117,7 @@ type inodeFileState struct { // loading is acquired when the inodeFileState begins an asynchronous // load. It releases when the load is complete. Callers that require all // state to be available should call waitForLoad() to ensure that. - loading sync.Mutex `state:".(struct{})"` + loading sync.CrossGoroutineMutex `state:".(struct{})"` // savedUAttr is only allocated during S/R. It points to the save-time // unstable attributes and is used to validate restore-time ones. diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 3c66dc3c2..6b3627813 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // maxFilenameLen is the maximum length of a filename. This is dictated by 9P's @@ -305,7 +304,7 @@ func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode, } // First create a pipe. - p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize) // Wrap the fileOps with our Fifo. iops := &fifo{ diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index 004910453..9b3d8166a 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -18,9 +18,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" @@ -28,8 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -var opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.") - // Inode is a file system object that can be simultaneously referenced by different // components of the VFS (Dirent, fs.File, etc). // @@ -247,7 +245,7 @@ func (i *Inode) GetFile(ctx context.Context, d *Dirent, flags FileFlags) (*File, if i.overlay != nil { return overlayGetFile(ctx, i.overlay, d, flags) } - opens.Increment() + fsmetric.Opens.Increment() return i.InodeOperations.GetFile(ctx, d, flags) } diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go index 8fe626e1c..e6171dd1d 100644 --- a/pkg/sentry/fs/proc/exec_args.go +++ b/pkg/sentry/fs/proc/exec_args.go @@ -57,16 +57,16 @@ type execArgInode struct { var _ fs.InodeOperations = (*execArgInode)(nil) // newExecArgFile creates a file containing the exec args of the given type. -func newExecArgInode(t *kernel.Task, msrc *fs.MountSource, arg execArgType) *fs.Inode { +func newExecArgInode(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, arg execArgType) *fs.Inode { if arg != cmdlineExecArg && arg != environExecArg { panic(fmt.Sprintf("unknown exec arg type %v", arg)) } f := &execArgInode{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), arg: arg, t: t, } - return newProcInode(t, f, msrc, fs.SpecialFile, t) + return newProcInode(ctx, f, msrc, fs.SpecialFile, t) } // GetFile implements fs.InodeOperations.GetFile. diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go index 45523adf8..e90da225a 100644 --- a/pkg/sentry/fs/proc/fds.go +++ b/pkg/sentry/fs/proc/fds.go @@ -95,13 +95,13 @@ var _ fs.InodeOperations = (*fd)(nil) // newFd returns a new fd based on an existing file. // // This inherits one reference to the file. -func newFd(t *kernel.Task, f *fs.File, msrc *fs.MountSource) *fs.Inode { +func newFd(ctx context.Context, t *kernel.Task, f *fs.File, msrc *fs.MountSource) *fs.Inode { fd := &fd{ // RootOwner overridden by taskOwnedInodeOps.UnstableAttrs(). - Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""), file: f, } - return newProcInode(t, fd, msrc, fs.Symlink, t) + return newProcInode(ctx, fd, msrc, fs.Symlink, t) } // GetFile returns the fs.File backing this fd. The dirent and flags @@ -153,12 +153,12 @@ type fdDir struct { var _ fs.InodeOperations = (*fdDir)(nil) // newFdDir creates a new fdDir. -func newFdDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newFdDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { f := &fdDir{ - Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermissions{User: fs.PermMask{Read: true, Execute: true}}), + Dir: *ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermissions{User: fs.PermMask{Read: true, Execute: true}}), t: t, } - return newProcInode(t, f, msrc, fs.SpecialDirectory, t) + return newProcInode(ctx, f, msrc, fs.SpecialDirectory, t) } // Check implements InodeOperations.Check. @@ -183,7 +183,7 @@ func (f *fdDir) Check(ctx context.Context, inode *fs.Inode, req fs.PermMask) boo // Lookup loads an Inode in /proc/TID/fd into a Dirent. func (f *fdDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dirent, error) { n, err := walkDescriptors(f.t, p, func(file *fs.File, _ kernel.FDFlags) *fs.Inode { - return newFd(f.t, file, dir.MountSource) + return newFd(ctx, f.t, file, dir.MountSource) }) if err != nil { return nil, err @@ -237,12 +237,12 @@ type fdInfoDir struct { } // newFdInfoDir creates a new fdInfoDir. -func newFdInfoDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newFdInfoDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { fdid := &fdInfoDir{ - Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermsFromMode(0500)), + Dir: *ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0500)), t: t, } - return newProcInode(t, fdid, msrc, fs.SpecialDirectory, t) + return newProcInode(ctx, fdid, msrc, fs.SpecialDirectory, t) } // Lookup loads an fd in /proc/TID/fdinfo into a Dirent. diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 83a43aa26..03127f816 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -41,7 +41,7 @@ import ( // LINT.IfChange // newNetDir creates a new proc net entry. -func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newNetDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { k := t.Kernel() var contents map[string]*fs.Inode @@ -49,39 +49,39 @@ func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]*fs.Inode{ - "dev": seqfile.NewSeqFileInode(t, &netDev{s: s}, msrc), - "snmp": seqfile.NewSeqFileInode(t, &netSnmp{s: s}, msrc), + "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), + "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), // The following files are simple stubs until they are // implemented in netstack, if the file contains a // header the stub is just the header otherwise it is // an empty file. - "arp": newStaticProcInode(t, msrc, []byte("IP address HW type Flags HW address Mask Device\n")), + "arp": newStaticProcInode(ctx, msrc, []byte("IP address HW type Flags HW address Mask Device\n")), - "netlink": newStaticProcInode(t, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")), - "netstat": newStaticProcInode(t, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")), - "packet": newStaticProcInode(t, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")), - "protocols": newStaticProcInode(t, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")), + "netlink": newStaticProcInode(ctx, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")), + "netstat": newStaticProcInode(ctx, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")), + "packet": newStaticProcInode(ctx, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")), + "protocols": newStaticProcInode(ctx, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")), // Linux sets psched values to: nsec per usec, psched // tick in ns, 1000000, high res timer ticks per sec // (ClockGetres returns 1ns resolution). - "psched": newStaticProcInode(t, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), - "ptype": newStaticProcInode(t, msrc, []byte("Type Device Function\n")), - "route": seqfile.NewSeqFileInode(t, &netRoute{s: s}, msrc), - "tcp": seqfile.NewSeqFileInode(t, &netTCP{k: k}, msrc), - "udp": seqfile.NewSeqFileInode(t, &netUDP{k: k}, msrc), - "unix": seqfile.NewSeqFileInode(t, &netUnix{k: k}, msrc), + "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))), + "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function\n")), + "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc), + "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc), + "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc), + "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc), } if s.SupportsIPv6() { - contents["if_inet6"] = seqfile.NewSeqFileInode(t, &ifinet6{s: s}, msrc) - contents["ipv6_route"] = newStaticProcInode(t, msrc, []byte("")) - contents["tcp6"] = seqfile.NewSeqFileInode(t, &netTCP6{k: k}, msrc) - contents["udp6"] = newStaticProcInode(t, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")) + contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc) + contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte("")) + contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc) + contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")) } } - d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) - return newProcInode(t, d, msrc, fs.SpecialDirectory, t) + d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) + return newProcInode(ctx, d, msrc, fs.SpecialDirectory, t) } // ifinet6 implements seqfile.SeqSource for /proc/net/if_inet6. diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go index 77e0e1d26..2f2a9f920 100644 --- a/pkg/sentry/fs/proc/proc.go +++ b/pkg/sentry/fs/proc/proc.go @@ -179,7 +179,7 @@ func (p *proc) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dire } // Wrap it in a taskDir. - td := p.newTaskDir(otherTask, dir.MountSource, true) + td := p.newTaskDir(ctx, otherTask, dir.MountSource, true) return fs.NewDirent(ctx, td, name), nil } diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index f8aad2dbd..b998fb75d 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -84,6 +84,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode children := map[string]*fs.Inode{ "hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil), + "sem": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), "shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))), "shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))), "shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))), diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index e555672ad..52061175f 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -86,9 +86,9 @@ func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { } // GetFile implements fs.InodeOperations.GetFile. -func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { +func (t *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true - return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil + return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: t}), nil } // +stateify savable diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 22d658acf..f43d6c221 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -79,44 +79,45 @@ type taskDir struct { var _ fs.InodeOperations = (*taskDir)(nil) // newTaskDir creates a new proc task entry. -func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { +func (p *proc) newTaskDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { contents := map[string]*fs.Inode{ - "auxv": newAuxvec(t, msrc), - "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), - "comm": newComm(t, msrc), - "cwd": newCwd(t, msrc), - "environ": newExecArgInode(t, msrc, environExecArg), - "exe": newExe(t, msrc), - "fd": newFdDir(t, msrc), - "fdinfo": newFdInfoDir(t, msrc), - "gid_map": newGIDMap(t, msrc), - "io": newIO(t, msrc, isThreadGroup), - "maps": newMaps(t, msrc), - "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), - "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), - "net": newNetDir(t, msrc), - "ns": newNamespaceDir(t, msrc), - "oom_score": newOOMScore(t, msrc), - "oom_score_adj": newOOMScoreAdj(t, msrc), - "smaps": newSmaps(t, msrc), - "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), - "statm": newStatm(t, msrc), - "status": newStatus(t, msrc, p.pidns), - "uid_map": newUIDMap(t, msrc), + "auxv": newAuxvec(ctx, t, msrc), + "cmdline": newExecArgInode(ctx, t, msrc, cmdlineExecArg), + "comm": newComm(ctx, t, msrc), + "cwd": newCwd(ctx, t, msrc), + "environ": newExecArgInode(ctx, t, msrc, environExecArg), + "exe": newExe(ctx, t, msrc), + "fd": newFdDir(ctx, t, msrc), + "fdinfo": newFdInfoDir(ctx, t, msrc), + "gid_map": newGIDMap(ctx, t, msrc), + "io": newIO(ctx, t, msrc, isThreadGroup), + "maps": newMaps(ctx, t, msrc), + "mem": newMem(ctx, t, msrc), + "mountinfo": seqfile.NewSeqFileInode(ctx, &mountInfoFile{t: t}, msrc), + "mounts": seqfile.NewSeqFileInode(ctx, &mountsFile{t: t}, msrc), + "net": newNetDir(ctx, t, msrc), + "ns": newNamespaceDir(ctx, t, msrc), + "oom_score": newOOMScore(ctx, msrc), + "oom_score_adj": newOOMScoreAdj(ctx, t, msrc), + "smaps": newSmaps(ctx, t, msrc), + "stat": newTaskStat(ctx, t, msrc, isThreadGroup, p.pidns), + "statm": newStatm(ctx, t, msrc), + "status": newStatus(ctx, t, msrc, p.pidns), + "uid_map": newUIDMap(ctx, t, msrc), } if isThreadGroup { - contents["task"] = p.newSubtasks(t, msrc) + contents["task"] = p.newSubtasks(ctx, t, msrc) } if len(p.cgroupControllers) > 0 { - contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers) + contents["cgroup"] = newCGroupInode(ctx, msrc, p.cgroupControllers) } // N.B. taskOwnedInodeOps enforces dumpability-based ownership. d := &taskDir{ - Dir: *ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)), + Dir: *ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)), t: t, } - return newProcInode(t, d, msrc, fs.SpecialDirectory, t) + return newProcInode(ctx, d, msrc, fs.SpecialDirectory, t) } // subtasks represents a /proc/TID/task directory. @@ -131,13 +132,13 @@ type subtasks struct { var _ fs.InodeOperations = (*subtasks)(nil) -func (p *proc) newSubtasks(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func (p *proc) newSubtasks(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { s := &subtasks{ - Dir: *ramfs.NewDir(t, nil, fs.RootOwner, fs.FilePermsFromMode(0555)), + Dir: *ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555)), t: t, p: p, } - return newProcInode(t, s, msrc, fs.SpecialDirectory, t) + return newProcInode(ctx, s, msrc, fs.SpecialDirectory, t) } // UnstableAttr returns unstable attributes of the subtasks. @@ -242,7 +243,7 @@ func (s *subtasks) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs.Dir return nil, syserror.ENOENT } - td := s.p.newTaskDir(task, dir.MountSource, false) + td := s.p.newTaskDir(ctx, task, dir.MountSource, false) return fs.NewDirent(ctx, td, p), nil } @@ -255,12 +256,12 @@ type exe struct { t *kernel.Task } -func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newExe(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { exeSymlink := &exe{ - Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""), t: t, } - return newProcInode(t, exeSymlink, msrc, fs.Symlink, t) + return newProcInode(ctx, exeSymlink, msrc, fs.Symlink, t) } func (e *exe) executable() (file fsbridge.File, err error) { @@ -310,12 +311,12 @@ type cwd struct { t *kernel.Task } -func newCwd(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newCwd(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { cwdSymlink := &cwd{ - Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, ""), t: t, } - return newProcInode(t, cwdSymlink, msrc, fs.Symlink, t) + return newProcInode(ctx, cwdSymlink, msrc, fs.Symlink, t) } // Readlink implements fs.InodeOperations. @@ -354,17 +355,17 @@ type namespaceSymlink struct { t *kernel.Task } -func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs.Inode { +func newNamespaceSymlink(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, name string) *fs.Inode { // TODO(rahat): Namespace symlinks should contain the namespace name and the // inode number for the namespace instance, so for example user:[123456]. We // currently fake the inode number by sticking the symlink inode in its // place. target := fmt.Sprintf("%s:[%d]", name, device.ProcDevice.NextIno()) n := &namespaceSymlink{ - Symlink: *ramfs.NewSymlink(t, fs.RootOwner, target), + Symlink: *ramfs.NewSymlink(ctx, fs.RootOwner, target), t: t, } - return newProcInode(t, n, msrc, fs.Symlink, t) + return newProcInode(ctx, n, msrc, fs.Symlink, t) } // Readlink reads the symlink value. @@ -389,14 +390,96 @@ func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Di return fs.NewDirent(ctx, newProcInode(ctx, iops, inode.MountSource, fs.RegularFile, nil), n.Symlink.Target), nil } -func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newNamespaceDir(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { contents := map[string]*fs.Inode{ - "net": newNamespaceSymlink(t, msrc, "net"), - "pid": newNamespaceSymlink(t, msrc, "pid"), - "user": newNamespaceSymlink(t, msrc, "user"), + "net": newNamespaceSymlink(ctx, t, msrc, "net"), + "pid": newNamespaceSymlink(ctx, t, msrc, "pid"), + "user": newNamespaceSymlink(ctx, t, msrc, "user"), } - d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0511)) - return newProcInode(t, d, msrc, fs.SpecialDirectory, t) + d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0511)) + return newProcInode(ctx, d, msrc, fs.SpecialDirectory, t) +} + +// memData implements fs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memData struct { + fsutil.SimpleFileInode + + t *kernel.Task +} + +// memDataFile implements fs.FileOperations for /proc/[pid]/mem. +// +// +stateify savable +type memDataFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + t *kernel.Task +} + +func newMem(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + inode := &memData{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC), + t: t, + } + return newProcInode(ctx, inode, msrc, fs.SpecialFile, t) +} + +// Truncate implements fs.InodeOperations.Truncate. +func (m *memData) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (m *memData) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, m.t, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(m.t); err != nil { + return nil, err + } + // Enable random access reads + flags.Pread = true + return fs.NewFile(ctx, dirent, flags, &memDataFile{t: m.t}), nil +} + +// Read implements fs.FileOperations.Read. +func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + mm, err := getTaskMM(m.t) + if err != nil { + return 0, nil + } + defer mm.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil } // mapsData implements seqfile.SeqSource for /proc/[pid]/maps. @@ -406,8 +489,8 @@ type mapsData struct { t *kernel.Task } -func newMaps(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &mapsData{t}), msrc, fs.SpecialFile, t) +func newMaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &mapsData{t}), msrc, fs.SpecialFile, t) } func (md *mapsData) mm() *mm.MemoryManager { @@ -446,8 +529,8 @@ type smapsData struct { t *kernel.Task } -func newSmaps(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &smapsData{t}), msrc, fs.SpecialFile, t) +func newSmaps(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &smapsData{t}), msrc, fs.SpecialFile, t) } func (sd *smapsData) mm() *mm.MemoryManager { @@ -492,8 +575,8 @@ type taskStatData struct { pidns *kernel.PIDNamespace } -func newTaskStat(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool, pidns *kernel.PIDNamespace) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &taskStatData{t, showSubtasks /* tgstats */, pidns}), msrc, fs.SpecialFile, t) +func newTaskStat(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, showSubtasks bool, pidns *kernel.PIDNamespace) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &taskStatData{t, showSubtasks /* tgstats */, pidns}), msrc, fs.SpecialFile, t) } // NeedsUpdate returns whether the generation is old or not. @@ -577,8 +660,8 @@ type statmData struct { t *kernel.Task } -func newStatm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &statmData{t}), msrc, fs.SpecialFile, t) +func newStatm(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &statmData{t}), msrc, fs.SpecialFile, t) } // NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. @@ -614,8 +697,8 @@ type statusData struct { pidns *kernel.PIDNamespace } -func newStatus(t *kernel.Task, msrc *fs.MountSource, pidns *kernel.PIDNamespace) *fs.Inode { - return newProcInode(t, seqfile.NewSeqFile(t, &statusData{t, pidns}), msrc, fs.SpecialFile, t) +func newStatus(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, pidns *kernel.PIDNamespace) *fs.Inode { + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &statusData{t, pidns}), msrc, fs.SpecialFile, t) } // NeedsUpdate implements seqfile.SeqSource.NeedsUpdate. @@ -685,11 +768,11 @@ type ioData struct { ioUsage } -func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { +func newIO(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { if isThreadGroup { - return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t) } - return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t) + return newProcInode(ctx, seqfile.NewSeqFile(ctx, &ioData{t}), msrc, fs.SpecialFile, t) } // NeedsUpdate returns whether the generation is old or not. @@ -733,12 +816,12 @@ type comm struct { } // newComm returns a new comm file. -func newComm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newComm(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { c := &comm{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), t: t, } - return newProcInode(t, c, msrc, fs.SpecialFile, t) + return newProcInode(ctx, c, msrc, fs.SpecialFile, t) } // Check implements fs.InodeOperations.Check. @@ -805,12 +888,12 @@ type auxvec struct { } // newAuxvec returns a new auxvec file. -func newAuxvec(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newAuxvec(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { a := &auxvec{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC), t: t, } - return newProcInode(t, a, msrc, fs.SpecialFile, t) + return newProcInode(ctx, a, msrc, fs.SpecialFile, t) } // GetFile implements fs.InodeOperations.GetFile. @@ -866,8 +949,8 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc // newOOMScore returns a oom_score file. It is a stub that always returns 0. // TODO(gvisor.dev/issue/1967) -func newOOMScore(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newStaticProcInode(t, msrc, []byte("0\n")) +func newOOMScore(ctx context.Context, msrc *fs.MountSource) *fs.Inode { + return newStaticProcInode(ctx, msrc, []byte("0\n")) } // oomScoreAdj is a file containing the oom_score adjustment for a task. @@ -896,12 +979,12 @@ type oomScoreAdjFile struct { } // newOOMScoreAdj returns a oom_score_adj file. -func newOOMScoreAdj(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { +func newOOMScoreAdj(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { i := &oomScoreAdj{ - SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), t: t, } - return newProcInode(t, i, msrc, fs.SpecialFile, t) + return newProcInode(ctx, i, msrc, fs.SpecialFile, t) } // Truncate implements fs.InodeOperations.Truncate. Truncate is called when diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go index 8d9517b95..2bc9485d8 100644 --- a/pkg/sentry/fs/proc/uid_gid_map.go +++ b/pkg/sentry/fs/proc/uid_gid_map.go @@ -58,18 +58,18 @@ type idMapInodeOperations struct { var _ fs.InodeOperations = (*idMapInodeOperations)(nil) // newUIDMap returns a new uid_map file. -func newUIDMap(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newIDMap(t, msrc, false /* gids */) +func newUIDMap(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newIDMap(ctx, t, msrc, false /* gids */) } // newGIDMap returns a new gid_map file. -func newGIDMap(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { - return newIDMap(t, msrc, true /* gids */) +func newGIDMap(ctx context.Context, t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newIDMap(ctx, t, msrc, true /* gids */) } -func newIDMap(t *kernel.Task, msrc *fs.MountSource, gids bool) *fs.Inode { - return newProcInode(t, &idMapInodeOperations{ - InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), +func newIDMap(ctx context.Context, t *kernel.Task, msrc *fs.MountSource, gids bool) *fs.Inode { + return newProcInode(ctx, &idMapInodeOperations{ + InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), t: t, gids: gids, }, msrc, fs.SpecialFile, t) diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index aa7199014..b521a86a2 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -15,12 +15,12 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", - "//pkg/metric", "//pkg/safemem", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index fc0498f17..e04cd608d 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -18,14 +18,13 @@ import ( "fmt" "io" "math" - "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -35,13 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -var ( - opensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.") - opensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.") - reads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.") - readWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.") -) - // fileInodeOperations implements fs.InodeOperations for a regular tmpfs file. // These files are backed by pages allocated from a platform.Memory, and may be // directly mapped. @@ -157,9 +149,9 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare // GetFile implements fs.InodeOperations.GetFile. func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { if flags.Write { - opensW.Increment() + fsmetric.TmpfsOpensW.Increment() } else if flags.Read { - opensRO.Increment() + fsmetric.TmpfsOpensRO.Increment() } flags.Pread = true flags.Pwrite = true @@ -319,14 +311,12 @@ func (*fileInodeOperations) StatFS(context.Context) (fs.Info, error) { } func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - var start time.Time - if fs.RecordWaitTime { - start = time.Now() - } - reads.Increment() + start := fsmetric.StartReadWait() + defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start) + fsmetric.TmpfsReads.Increment() + // Zero length reads for tmpfs are no-ops. if dst.NumBytes() == 0 { - fs.IncrementWait(readWait, start) return 0, nil } @@ -343,7 +333,6 @@ func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst userm size := f.attr.Size f.dataMu.RUnlock() if offset >= size { - fs.IncrementWait(readWait, start) return 0, io.EOF } @@ -354,7 +343,6 @@ func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst userm f.attr.AccessTime = ktime.NowFromContext(ctx) f.attrMu.Unlock() } - fs.IncrementWait(readWait, start) return n, err } @@ -431,9 +419,6 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // Continue. seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} - - default: - break } } return done, nil @@ -532,9 +517,6 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) // Write to that memory as usual. seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{} - - default: - break } } return done, nil diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 998b697ca..cf4ed5de0 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -336,7 +336,7 @@ type Fifo struct { // NewFifo creates a new named pipe. func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { // First create a pipe. - p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize) // Build pipe InodeOperations. iops := pipe.NewInodeOperations(ctx, perms, p) diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 84baaac66..6af3c3781 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "root_inode_refs.go", package = "devpts", prefix = "rootInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "rootInode", }, @@ -33,6 +33,7 @@ go_library( "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index d5c5aaa8c..d8c237753 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -60,7 +60,7 @@ func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Vir } fstype.initOnce.Do(func() { - fs, root, err := fstype.newFilesystem(vfsObj, creds) + fs, root, err := fstype.newFilesystem(ctx, vfsObj, creds) if err != nil { fstype.initErr = err return @@ -93,7 +93,7 @@ type filesystem struct { // newFilesystem creates a new devpts filesystem with root directory and ptmx // master inode. It returns the filesystem and root Dentry. -func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { +func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err @@ -108,19 +108,19 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds root := &rootInode{ replicas: make(map[uint32]*replicaInode), } - root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) + root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - root.EnableLeakCheck() + root.InitRefs() var rootD kernfs.Dentry - rootD.Init(&fs.Filesystem, root) + rootD.InitRoot(&fs.Filesystem, root) // Construct the pts master inode and dentry. Linux always uses inode // id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx. master := &masterInode{ root: root, } - master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) + master.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) // Add the master as a child of the root. links := root.OrderedChildren.Populate(map[string]kernfs.Inode{ @@ -170,7 +170,7 @@ type rootInode struct { var _ kernfs.Inode = (*rootInode)(nil) // allocateTerminal creates a new Terminal and installs a pts node for it. -func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) { +func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credentials) (*Terminal, error) { i.mu.Lock() defer i.mu.Unlock() if i.nextIdx == math.MaxUint32 { @@ -192,7 +192,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). - replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) + replica.InodeAttrs.Init(ctx, creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) i.replicas[idx] = replica return t, nil @@ -248,9 +248,10 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro } // IterDirents implements kernfs.Inode.IterDirents. -func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *rootInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { i.mu.Lock() defer i.mu.Unlock() + i.InodeAttrs.TouchAtime(ctx, mnt) ids := make([]int, 0, len(i.replicas)) for id := range i.replicas { ids = append(ids, int(id)) diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index e6b0e81cf..ae95fdd08 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -100,10 +100,10 @@ type lineDiscipline struct { column int // masterWaiter is used to wait on the master end of the TTY. - masterWaiter waiter.Queue `state:"zerovalue"` + masterWaiter waiter.Queue // replicaWaiter is used to wait on the replica end of the TTY. - replicaWaiter waiter.Queue `state:"zerovalue"` + replicaWaiter waiter.Queue } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index fda30fb93..e91fa26a4 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -50,7 +50,7 @@ var _ kernfs.Inode = (*masterInode)(nil) // Open implements kernfs.Inode.Open. func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - t, err := mi.root.allocateTerminal(rp.Credentials()) + t, err := mi.root.allocateTerminal(ctx, rp.Credentials()) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD index 01bbee5ad..e49a04c1b 100644 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ b/pkg/sentry/fsimpl/devtmpfs/BUILD @@ -4,7 +4,10 @@ licenses(["notice"]) go_library( name = "devtmpfs", - srcs = ["devtmpfs.go"], + srcs = [ + "devtmpfs.go", + "save_restore.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/goid/goid_race.go b/pkg/sentry/fsimpl/devtmpfs/save_restore.go index 1766beaee..28832d850 100644 --- a/pkg/goid/goid_race.go +++ b/pkg/sentry/fsimpl/devtmpfs/save_restore.go @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Only available in race/gotsan builds. -// +build race +package devtmpfs -// Package goid provides access to the ID of the current goroutine in -// race/gotsan builds. -package goid - -// Get returns the ID of the current goroutine. -func Get() int64 { - return goid() +// afterLoad is invoked by stateify. +func (fst *FilesystemType) afterLoad() { + if fst.fs != nil { + // Ensure that we don't create another filesystem. + fst.initOnce.Do(func() {}) + } } diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 1c27ad700..5b29f2358 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -43,7 +43,7 @@ type EventFileDescription struct { // queue is used to notify interested parties when the event object // becomes readable or writable. - queue waiter.Queue `state:"zerovalue"` + queue waiter.Queue // mu protects the fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 045d7ab08..2158b1bbc 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "inode_refs.go", package = "fuse", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -49,6 +49,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/kernfs", diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go index 8ccda1264..34d25a61e 100644 --- a/pkg/sentry/fsimpl/fuse/connection.go +++ b/pkg/sentry/fsimpl/fuse/connection.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -193,11 +192,12 @@ func (conn *connection) loadInitializedChan(closed bool) { } } -// newFUSEConnection creates a FUSE connection to fd. -func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, opts *filesystemOptions) (*connection, error) { - // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to - // mount a FUSE filesystem. - fuseFD := fd.Impl().(*DeviceFD) +// newFUSEConnection creates a FUSE connection to fuseFD. +func newFUSEConnection(_ context.Context, fuseFD *DeviceFD, opts *filesystemOptions) (*connection, error) { + // Mark the device as ready so it can be used. + // FIXME(gvisor.dev/issue/4813): fuseFD's fields are accessed without + // synchronization and without checking if fuseFD has already been used to + // mount another filesystem. // Create the writeBuf for the header to be stored in. hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go index bfde78559..4ab894965 100644 --- a/pkg/sentry/fsimpl/fuse/connection_control.go +++ b/pkg/sentry/fsimpl/fuse/connection_control.go @@ -84,11 +84,7 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { Flags: fuseDefaultInitFlags, } - req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in) - if err != nil { - return err - } - + req := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in) // Since there is no task to block on and FUSE_INIT is the request // to unblock other requests, use nil. return conn.CallAsync(nil, req) @@ -198,7 +194,6 @@ func (conn *connection) Abort(ctx context.Context) { if !conn.connected { conn.asyncMu.Unlock() conn.mu.Unlock() - conn.fd.mu.Unlock() return } diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go index 91d16c1cf..d8b0d7657 100644 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -76,10 +76,7 @@ func TestConnectionAbort(t *testing.T) { var futNormal []*futureResponse for i := 0; i < int(numRequests); i++ { - req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) fut, err := conn.callFutureLocked(task, req) if err != nil { t.Fatalf("callFutureLocked failed: %v", err) @@ -105,10 +102,7 @@ func TestConnectionAbort(t *testing.T) { } // After abort, Call() should return directly with ENOTCONN. - req, err := conn.NewRequest(creds, 0, 0, 0, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, 0, 0, 0, testObj) _, err = conn.Call(task, req) if err != syserror.ENOTCONN { t.Fatalf("Incorrect error code received for Call() after connection aborted") diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index 1b86a4b4c..1bbe6fdb7 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -94,7 +94,8 @@ type DeviceFD struct { // unprocessed in-flight requests. fullQueueCh chan struct{} `state:".(int)"` - // fs is the FUSE filesystem that this FD is being used for. + // fs is the FUSE filesystem that this FD is being used for. A reference is + // held on fs. fs *filesystem } @@ -135,12 +136,6 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R return 0, syserror.EPERM } - // Return ENODEV if the filesystem is umounted. - if fd.fs.umounted { - // TODO(gvisor.dev/issue/3525): return ECONNABORTED if aborted via fuse control fs. - return 0, syserror.ENODEV - } - // We require that any Read done on this filesystem have a sane minimum // read buffer. It must have the capacity for the fixed parts of any request // header (Linux uses the request header and the FUSEWriteIn header for this @@ -368,7 +363,7 @@ func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask { var ready waiter.EventMask - if fd.fs.umounted { + if fd.fs == nil || fd.fs.umounted { ready |= waiter.EventErr return ready & mask } diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 5986133e9..bb2d0d31a 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -219,10 +219,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con data: rand.Uint32(), } - req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) // Queue up a request. // Analogous to Call except it doesn't block on the task. @@ -315,7 +312,7 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F readPayload.MarshalUnsafe(outBuf[outHdrLen:]) outIOseq := usermem.BytesIOSequence(outBuf) - n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) + _, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) if err != nil { t.Fatalf("Write failed :%v", err) } diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go index 8f220a04b..fcc5d9a2a 100644 --- a/pkg/sentry/fsimpl/fuse/directory.go +++ b/pkg/sentry/fsimpl/fuse/directory.go @@ -68,11 +68,7 @@ func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirent } // TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS. - req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) - if err != nil { - return err - } - + req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) res, err := fusefs.conn.Call(task, req) if err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go index 83f2816b7..e138b11f8 100644 --- a/pkg/sentry/fsimpl/fuse/file.go +++ b/pkg/sentry/fsimpl/fuse/file.go @@ -83,12 +83,8 @@ func (fd *fileDescription) Release(ctx context.Context) { opcode = linux.FUSE_RELEASE } kernelTask := kernel.TaskFromContext(ctx) - // ignoring errors and FUSE server reply is analogous to Linux's behavior. - req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) - if err != nil { - // No way to invoke Call() with an errored request. - return - } + // Ignoring errors and FUSE server reply is analogous to Linux's behavior. + req := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) // The reply will be ignored since no callback is defined in asyncCallBack(). conn.CallAsync(kernelTask, req) } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index e39df21c6..3af807a21 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -119,7 +119,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) if err != nil { - return nil, nil, err + log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err) + return nil, nil, syserror.EINVAL } kernelTask := kernel.TaskFromContext(ctx) @@ -127,7 +128,13 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name()) return nil, nil, syserror.EINVAL } - fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + defer fuseFDGeneric.DecRef(ctx) + fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD) + if !ok { + log.Warningf("%s.GetFilesystem: device FD is %T, not a FUSE device", fsType.Name, fuseFDGeneric) + return nil, nil, syserror.EINVAL + } // Parse and set all the other supported FUSE mount options. // TODO(gVisor.dev/issue/3229): Expand the supported mount options. @@ -189,44 +196,47 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // Create a new FUSE filesystem. - fs, err := newFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) + fs, err := newFUSEFilesystem(ctx, vfsObj, &fsType, fuseFD, devMinor, &fsopts) if err != nil { log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err) return nil, nil, err } - fs.VFSFilesystem().Init(vfsObj, &fsType, fs) - // Send a FUSE_INIT request to the FUSE daemon server before returning. // This call is not blocking. if err := fs.conn.InitSend(creds, uint32(kernelTask.ThreadID())); err != nil { log.Warningf("%s.InitSend: failed with error: %v", fsType.Name(), err) + fs.VFSFilesystem().DecRef(ctx) // returned by newFUSEFilesystem return nil, nil, err } // root is the fusefs root directory. - root := fs.newRootInode(creds, fsopts.rootMode) + root := fs.newRoot(ctx, creds, fsopts.rootMode) return fs.VFSFilesystem(), root.VFSDentry(), nil } // newFUSEFilesystem creates a new FUSE filesystem. -func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { - conn, err := newFUSEConnection(ctx, device, opts) +func newFUSEFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, fsType *FilesystemType, fuseFD *DeviceFD, devMinor uint32, opts *filesystemOptions) (*filesystem, error) { + conn, err := newFUSEConnection(ctx, fuseFD, opts) if err != nil { log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err) return nil, syserror.EINVAL } - fuseFD := device.Impl().(*DeviceFD) - fs := &filesystem{ devMinor: devMinor, opts: opts, conn: conn, } + fs.VFSFilesystem().Init(vfsObj, fsType, fs) - fs.VFSFilesystem().IncRef() + // FIXME(gvisor.dev/issue/4813): Doesn't conn or fs need to hold a + // reference on fuseFD, since conn uses fuseFD for communication with the + // server? Wouldn't doing so create a circular reference? + fs.VFSFilesystem().IncRef() // for fuseFD.fs + // FIXME(gvisor.dev/issue/4813): fuseFD.fs is accessed without + // synchronization. fuseFD.fs = fs return fs, nil @@ -284,24 +294,24 @@ type inode struct { link string } -func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { +func (fs *filesystem) newRoot(ctx context.Context, creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { i := &inode{fs: fs, nodeID: 1} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) + i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - i.EnableLeakCheck() + i.InitRefs() var d kernfs.Dentry - d.Init(&fs.Filesystem, i) + d.InitRoot(&fs.Filesystem, i) return &d } -func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { +func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { i := &inode{fs: fs, nodeID: nodeID} creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)} - i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) + i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) atomic.StoreUint64(&i.size, attr.Size) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - i.EnableLeakCheck() + i.InitRefs() return i } @@ -351,12 +361,8 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr in.Flags &= ^uint32(linux.O_TRUNC) } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) - if err != nil { - return nil, err - } - // Send the request and receive the reply. + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return nil, err @@ -424,7 +430,7 @@ func (i *inode) Keep() bool { } // IterDirents implements kernfs.Inode.IterDirents. -func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (*inode) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { return offset, nil } @@ -476,10 +482,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err return syserror.EINVAL } in := linux.FUSEUnlinkIn{Name: name} - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) - if err != nil { - return err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return err @@ -506,11 +509,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) in := linux.FUSERmDirIn{Name: name} - req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) - if err != nil { - return err - } - + req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { return err @@ -526,10 +525,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) return nil, syserror.EINVAL } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) - if err != nil { - return nil, err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return nil, err @@ -544,7 +540,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { return nil, syserror.EIO } - child := i.fs.newInode(out.NodeID, out.Attr) + child := i.fs.newInode(ctx, out.NodeID, out.Attr) return child, nil } @@ -565,10 +561,7 @@ func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context") return "", syserror.EINVAL } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) - if err != nil { - return "", err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return "", err @@ -671,11 +664,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp GetAttrFlags: flags, Fh: fh, } - req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) - if err != nil { - return linux.FUSEAttr{}, err - } - + req := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { return linux.FUSEAttr{}, err @@ -696,7 +685,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return linux.FUSEAttr{}, err @@ -794,11 +783,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre UID: opts.Stat.UID, GID: opts.Stat.GID, } - req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) - if err != nil { - return err - } - + req := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) res, err := conn.Call(task, req) if err != nil { return err @@ -812,7 +797,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 625d1547f..23ce91849 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -79,13 +79,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui in.Offset = off + (uint64(pagesRead) << usermem.PageShift) in.Size = pagesCanRead << usermem.PageShift - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) - if err != nil { - return nil, 0, err - } - // TODO(gvisor.dev/issue/3247): support async read. + req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) res, err := fs.conn.Call(t, req) if err != nil { return nil, 0, err @@ -132,7 +128,7 @@ func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off u // May need to update the signature. i := fd.inode() - // TODO(gvisor.dev/issue/1193): Invalidate or update atime. + i.InodeAttrs.TouchAtime(ctx, fd.vfsfd.Mount()) // Reached EOF. if sizeRead < size { @@ -179,6 +175,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, Flags: fd.statusFlags(), } + inode := fd.inode() var written uint32 // This loop is intended for fragmented write where the bytes to write is @@ -203,11 +200,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, in.Offset = off + uint64(written) in.Size = toWrite - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in) - if err != nil { - return 0, err - } - + req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) req.payload = data[written : written+toWrite] // TODO(gvisor.dev/issue/3247): support async write. @@ -237,6 +230,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, break } } + inode.InodeAttrs.TouchCMtime(ctx) return written, nil } diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index 7fa00569b..41d679358 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) src = src[2:] } + _ = src // Remove unused warning. } // SizeBytes is the size of the payload of the FUSE_INIT response. @@ -104,7 +105,7 @@ type Request struct { } // NewRequest creates a new request that can be sent to the FUSE server. -func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { +func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) *Request { conn.fd.mu.Lock() defer conn.fd.mu.Unlock() conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) @@ -130,7 +131,7 @@ func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint id: hdr.Unique, hdr: &hdr, data: buf, - }, nil + } } // futureResponse represents an in-flight request, that may or may not have diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go index e1d9e3365..2c0cc0f4e 100644 --- a/pkg/sentry/fsimpl/fuse/utils_test.go +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -52,27 +52,21 @@ func setup(t *testing.T) *testutil.System { // newTestConnection creates a fuse connection that the sentry can communicate with // and the FD for the server to communicate with. func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { - vfsObj := &vfs.VirtualFilesystem{} fuseDev := &DeviceFD{} - if err := vfsObj.Init(system.Ctx); err != nil { - return nil, nil, err - } - - vd := vfsObj.NewAnonVirtualDentry("genCountFD") + vd := system.VFS.NewAnonVirtualDentry("fuse") defer vd.DecRef(system.Ctx) - if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { + if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { return nil, nil, err } fsopts := filesystemOptions{ maxActiveRequests: maxActiveRequests, } - fs, err := newFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) + fs, err := newFUSEFilesystem(system.Ctx, system.VFS, &FilesystemType{}, fuseDev, 0, &fsopts) if err != nil { return nil, nil, err } - return fs.conn, &fuseDev.vfsfd, nil } diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index ad0afc41b..807b6ed1f 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "host_named_pipe.go", "p9file.go", "regular_file.go", + "save_restore.go", "socket.go", "special_file.go", "symlink.go", @@ -53,10 +54,12 @@ go_library( "//pkg/log", "//pkg/p9", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/host", + "//pkg/sentry/fsmetric", "//pkg/sentry/hostfd", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -70,6 +73,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 18c884b59..3b5927702 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -16,16 +16,17 @@ package gofer import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -92,14 +93,17 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { child := &dentry{ refs: 1, // held by d fs: d.fs, - ino: d.fs.nextSyntheticIno(), + ino: d.fs.nextIno(), mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), blockSize: usermem.PageSize, // arbitrary - hostFD: -1, + readFD: -1, + writeFD: -1, + mmapFD: -1, nlink: uint32(2), } + refsvfs2.Register(child) switch opts.mode.FileType() { case linux.S_IFDIR: // Nothing else needs to be done. @@ -235,7 +239,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { } dirent := vfs.Dirent{ Name: p9d.Name, - Ino: uint64(inoFromPath(p9d.QID.Path)), + Ino: d.fs.inoFromQIDPath(p9d.QID.Path), NextOff: int64(len(dirents) + 1), } // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 94d96261b..df27554d3 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -24,18 +24,18 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. func (fs *filesystem) Sync(ctx context.Context) error { - // Snapshot current syncable dentries and special files. + // Snapshot current syncable dentries and special file FDs. fs.syncMu.Lock() ds := make([]*dentry, 0, len(fs.syncableDentries)) for d := range fs.syncableDentries { @@ -53,22 +53,28 @@ func (fs *filesystem) Sync(ctx context.Context) error { // regardless. var retErr error - // Sync regular files. + // Sync syncable dentries. for _, d := range ds { - err := d.syncCachedFile(ctx) + err := d.syncCachedFile(ctx, true /* forFilesystemSync */) d.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) + if retErr == nil { + retErr = err + } } } // Sync special files, which may be writable but do not use dentry shared // handles (so they won't be synced by the above). for _, sffd := range sffds { - err := sffd.Sync(ctx) + err := sffd.sync(ctx, true /* forFilesystemSync */) sffd.vfsfd.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) + if retErr == nil { + retErr = err + } } } @@ -109,6 +115,51 @@ func putDentrySlice(ds *[]*dentry) { dentrySlicePool.Put(ds) } +// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls +// dentry.checkCachingLocked on all dentries in *dsp with fs.renameMu locked +// for writing. +// +// dsp is a pointer-to-pointer since defer evaluates its arguments immediately, +// but dentry slices are allocated lazily, and it's much easier to say "defer +// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { +// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. +func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **[]*dentry) { + fs.renameMu.RUnlock() + if *dsp == nil { + return + } + ds := **dsp + // Only go through calling dentry.checkCachingLocked() (which requires + // re-locking renameMu) if we actually have any dentries with zero refs. + checkAny := false + for i := range ds { + if atomic.LoadInt64(&ds[i].refs) == 0 { + checkAny = true + break + } + } + if checkAny { + fs.renameMu.Lock() + for _, d := range ds { + d.checkCachingLocked(ctx) + } + fs.renameMu.Unlock() + } + putDentrySlice(*dsp) +} + +func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { + if *ds == nil { + fs.renameMu.Unlock() + return + } + for _, d := range **ds { + d.checkCachingLocked(ctx) + } + fs.renameMu.Unlock() + putDentrySlice(*ds) +} + // stepLocked resolves rp.Component() to an existing file, starting from the // given directory. // @@ -229,7 +280,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir return nil, err } if child != nil { - if !file.isNil() && inoFromPath(qid.Path) == child.ino { + if !file.isNil() && qid.Path == child.qidPath { // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) child.updateFromP9AttrsLocked(attrMask, &attr) @@ -256,7 +307,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // treat their invalidation as deletion. child.setDeleted() parent.syntheticChildren-- - child.decRefLocked() + child.decRefNoCaching() parent.dirents = nil } *ds = appendDentry(*ds, child) @@ -366,9 +417,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if len(name) > maxFilenameLen { return syserror.ENAMETOOLONG } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if parent.isDeleted() { return syserror.ENOENT } @@ -383,6 +431,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } if createInSyntheticDir == nil { return syserror.EPERM } @@ -402,6 +453,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil && child.isSynthetic() { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } // The existence of a non-synthetic dentry at name would be inconclusive // because the file it represents may have been deleted from the remote // filesystem, so we would need to make an RPC to revalidate the dentry. @@ -422,6 +476,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } // No cached dentry exists; however, there might still be an existing file // at name. As above, we attempt the file creation RPC anyway. if err := createInRemoteDir(parent, name, &ds); err != nil { @@ -625,7 +682,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b child.setDeleted() if child.isSynthetic() { parent.syntheticChildren-- - child.decRefLocked() + child.decRefNoCaching() } ds = appendDentry(ds, child) } @@ -640,41 +697,6 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b return nil } -// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls -// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for -// writing. -// -// ds is a pointer-to-pointer since defer evaluates its arguments immediately, -// but dentry slices are allocated lazily, and it's much easier to say "defer -// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() { -// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { - fs.renameMu.RUnlock() - if *ds == nil { - return - } - if len(**ds) != 0 { - fs.renameMu.Lock() - for _, d := range **ds { - d.checkCachingLocked(ctx) - } - fs.renameMu.Unlock() - } - putDentrySlice(*ds) -} - -func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) { - if *ds == nil { - fs.renameMu.Unlock() - return - } - for _, d := range **ds { - d.checkCachingLocked(ctx) - } - fs.renameMu.Unlock() - putDentrySlice(*ds) -} - // AccessAt implements vfs.Filesystem.Impl.AccessAt. func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error { var ds *[]*dentry @@ -836,7 +858,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v mode: opts.Mode, kuid: creds.EffectiveKUID, kgid: creds.EffectiveKGID, - pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) return nil } @@ -964,14 +986,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open switch d.fileType() { case linux.S_IFREG: if !d.fs.opts.regularFilesUseSpecialFileFD { - if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil { + if err := d.ensureSharedHandle(ctx, ats.MayRead(), ats.MayWrite(), trunc); err != nil { return nil, err } - fd := ®ularFileFD{} - fd.LockFD.Init(&d.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ - AllowDirectIO: true, - }); err != nil { + fd, err := newRegularFileFD(mnt, d, opts.Flags) + if err != nil { return nil, err } vfd = &fd.vfsfd @@ -998,6 +1017,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } + if atomic.LoadInt32(&d.readFD) >= 0 { + fsmetric.GoferOpensHost.Increment() + } else { + fsmetric.GoferOpens9P.Increment() + } return &fd.vfsfd, nil case linux.S_IFLNK: // Can't open symlinks without O_PATH (which is unimplemented). @@ -1089,7 +1113,7 @@ retry: return nil, err } } - fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags) + fd, err := newSpecialFileFD(h, mnt, d, opts.Flags) if err != nil { h.close(ctx) return nil, err @@ -1156,18 +1180,21 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { + openFD := int32(-1) + if fdobj != nil { + openFD = int32(fdobj.Release()) + } child.handleMu.Lock() if vfs.MayReadFileWithOpenFlags(opts.Flags) { child.readFile = openFile if fdobj != nil { - child.hostFD = int32(fdobj.Release()) + child.readFD = openFD + child.mmapFD = openFD } - } else if fdobj != nil { - // Can't use fdobj if it's not readable. - fdobj.Close() } if vfs.MayWriteFileWithOpenFlags(opts.Flags) { child.writeFile = openFile + child.writeFD = openFD } child.handleMu.Unlock() } @@ -1181,11 +1208,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving // Finally, construct a file description representing the created file. var childVFSFD *vfs.FileDescription if useRegularFileFD { - fd := ®ularFileFD{} - fd.LockFD.Init(&child.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{ - AllowDirectIO: true, - }); err != nil { + fd, err := newRegularFileFD(mnt, child, opts.Flags) + if err != nil { return nil, err } childVFSFD = &fd.vfsfd @@ -1197,7 +1221,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving if fdobj != nil { h.fd = int32(fdobj.Release()) } - fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags) + fd, err := newSpecialFileFD(h, mnt, child, opts.Flags) if err != nil { h.close(ctx) return nil, err @@ -1355,7 +1379,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa replaced.setDeleted() if replaced.isSynthetic() { newParent.syntheticChildren-- - replaced.decRefLocked() + replaced.decRefNoCaching() } ds = appendDentry(ds, replaced) } @@ -1364,7 +1388,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // with reference counts and queue oldParent for checkCachingLocked if the // parent isn't actually changing. if oldParent != newParent { - oldParent.decRefLocked() + oldParent.decRefNoCaching() ds = appendDentry(ds, oldParent) newParent.IncRef() if renamed.isSynthetic() { @@ -1512,7 +1536,6 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath d.IncRef() return &endpoint{ dentry: d, - file: d.file.file, path: opts.Addr, }, nil } @@ -1591,7 +1614,3 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } - -func (fs *filesystem) nextSyntheticIno() inodeNumber { - return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask) -} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index f1dad1b08..3cdb1e659 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -26,6 +26,9 @@ // *** "memmap.Mappable locks taken by Translate" below this point // dentry.handleMu // dentry.dataMu +// filesystem.inoMu +// specialFileFD.mu +// specialFileFD.bufMu // // Locking dentry.dirMu in multiple dentries requires that either ancestor // dentries are locked before descendant dentries, or that filesystem.renameMu @@ -36,7 +39,6 @@ import ( "fmt" "strconv" "strings" - "sync" "sync/atomic" "syscall" @@ -44,6 +46,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -53,6 +57,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/usermem" @@ -81,7 +86,7 @@ type filesystem struct { iopts InternalFilesystemOptions // client is the client used by this filesystem. client is immutable. - client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + client *p9.Client `state:"nosave"` // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -89,6 +94,9 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 + // root is the root dentry. root is immutable. + root *dentry + // renameMu serves two purposes: // // - It synchronizes path resolution with renaming initiated by this @@ -103,39 +111,35 @@ type filesystem struct { // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) - // cachedDentriesLen is the number of dentries in cachedDentries. These - // fields are protected by renameMu. + // cachedDentriesLen is the number of dentries in cachedDentries. These fields + // are protected by renameMu. cachedDentries dentryList cachedDentriesLen uint64 - // syncableDentries contains all dentries in this filesystem for which - // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs. - // These fields are protected by syncMu. + // syncableDentries contains all non-synthetic dentries. specialFileFDs + // contains all open specialFileFDs. These fields are protected by syncMu. syncMu sync.Mutex `state:"nosave"` syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} - // syntheticSeq stores a counter to used to generate unique inodeNumber for - // synthetic dentries. - syntheticSeq uint64 -} + // inoByQIDPath maps previously-observed QID.Paths to inode numbers + // assigned to those paths. inoByQIDPath is not preserved across + // checkpoint/restore because QIDs may be reused between different gofer + // processes, so QIDs may be repeated for different files across + // checkpoint/restore. inoByQIDPath is protected by inoMu. + inoMu sync.Mutex `state:"nosave"` + inoByQIDPath map[uint64]uint64 `state:"nosave"` -// inodeNumber represents inode number reported in Dirent.Ino. For regular -// dentries, it comes from QID.Path from the 9P server. Synthetic dentries -// have have their inodeNumber generated sequentially, with the MSB reserved to -// prevent conflicts with regular dentries. -// -// +stateify savable -type inodeNumber uint64 + // lastIno is the last inode number assigned to a file. lastIno is accessed + // using atomic memory operations. + lastIno uint64 -// Reserve MSB for synthetic mounts. -const syntheticInoMask = uint64(1) << 63 + // savedDentryRW records open read/write handles during save/restore. + savedDentryRW map[*dentry]savedDentryRW -func inoFromPath(path uint64) inodeNumber { - if path&syntheticInoMask != 0 { - log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask) - } - return inodeNumber(path &^ syntheticInoMask) + // released is nonzero once filesystem.Release has been called. It is accessed + // with atomic memory operations. + released int32 } // +stateify savable @@ -149,8 +153,7 @@ type filesystemOptions struct { msize uint32 version string - // maxCachedDentries is the maximum number of dentries with 0 references - // retained by the client. + // maxCachedDentries is the maximum size of filesystem.cachedDentries. maxCachedDentries uint64 // If forcePageCache is true, host FDs may not be used for application @@ -247,6 +250,10 @@ const ( // // +stateify savable type InternalFilesystemOptions struct { + // If UniqueID is non-empty, it is an opaque string used to reassociate the + // filesystem with a new server FD during restoration from checkpoint. + UniqueID string + // If LeakConnection is true, do not close the connection to the server // when the Filesystem is released. This is necessary for deployments in // which servers can handle only a single client and report failure if that @@ -286,46 +293,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mopts := vfs.GenericParseMountOptions(opts.Data) var fsopts filesystemOptions - // Check that the transport is "fd". - trans, ok := mopts["trans"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'") - return nil, nil, syserror.EINVAL - } - delete(mopts, "trans") - if trans != "fd" { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans) - return nil, nil, syserror.EINVAL - } - - // Check that read and write FDs are provided and identical. - rfdstr, ok := mopts["rfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "rfdno") - rfd, err := strconv.Atoi(rfdstr) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr) - return nil, nil, syserror.EINVAL - } - wfdstr, ok := mopts["wfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "wfdno") - wfd, err := strconv.Atoi(wfdstr) + fd, err := getFDFromMountOptionsMap(ctx, mopts) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr) - return nil, nil, syserror.EINVAL - } - if rfd != wfd { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd) - return nil, nil, syserror.EINVAL + return nil, nil, err } - fsopts.fd = rfd + fsopts.fd = fd // Get the attach name. fsopts.aname = "/" @@ -441,56 +413,43 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // If !ok, iopts being the zero value is correct. - // Establish a connection with the server. - conn, err := unet.NewSocket(fsopts.fd) + // Construct the filesystem object. + devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } + fs := &filesystem{ + mfp: mfp, + opts: fsopts, + iopts: iopts, + clock: ktime.RealtimeClockFromContext(ctx), + devMinor: devMinor, + syncableDentries: make(map[*dentry]struct{}), + specialFileFDs: make(map[*specialFileFD]struct{}), + inoByQIDPath: make(map[uint64]uint64), + } + fs.vfsfs.Init(vfsObj, &fstype, fs) - // Perform version negotiation with the server. - ctx.UninterruptibleSleepStart(false) - client, err := p9.NewClient(conn, fsopts.msize, fsopts.version) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - conn.Close() + // Connect to the server. + if err := fs.dial(ctx); err != nil { return nil, nil, err } - // Ownership of conn has been transferred to client. // Perform attach to obtain the filesystem root. ctx.UninterruptibleSleepStart(false) - attached, err := client.Attach(fsopts.aname) + attached, err := fs.client.Attach(fsopts.aname) ctx.UninterruptibleSleepFinish(false) if err != nil { - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } attachFile := p9file{attached} qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) if err != nil { attachFile.close(ctx) - client.Close() - return nil, nil, err - } - - // Construct the filesystem object. - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - attachFile.close(ctx) - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } - fs := &filesystem{ - mfp: mfp, - opts: fsopts, - iopts: iopts, - client: client, - clock: ktime.RealtimeClockFromContext(ctx), - devMinor: devMinor, - syncableDentries: make(map[*dentry]struct{}), - specialFileFDs: make(map[*specialFileFD]struct{}), - } - fs.vfsfs.Init(vfsObj, &fstype, fs) // Construct the root dentry. root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) @@ -500,25 +459,87 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } // Set the root's reference count to 2. One reference is returned to the - // caller, and the other is deliberately leaked to prevent the root from - // being "cached" and subsequently evicted. Its resources will still be - // cleaned up by fs.Release(). + // caller, and the other is held by fs to prevent the root from being "cached" + // and subsequently evicted. root.refs = 2 + fs.root = root return &fs.vfsfs, &root.vfsd, nil } +func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { + // Check that the transport is "fd". + trans, ok := mopts["trans"] + if !ok || trans != "fd" { + ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'") + return -1, syserror.EINVAL + } + delete(mopts, "trans") + + // Check that read and write FDs are provided and identical. + rfdstr, ok := mopts["rfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "rfdno") + rfd, err := strconv.Atoi(rfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr) + return -1, syserror.EINVAL + } + wfdstr, ok := mopts["wfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "wfdno") + wfd, err := strconv.Atoi(wfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr) + return -1, syserror.EINVAL + } + if rfd != wfd { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd) + return -1, syserror.EINVAL + } + return rfd, nil +} + +// Preconditions: fs.client == nil. +func (fs *filesystem) dial(ctx context.Context) error { + // Establish a connection with the server. + conn, err := unet.NewSocket(fs.opts.fd) + if err != nil { + return err + } + + // Perform version negotiation with the server. + ctx.UninterruptibleSleepStart(false) + client, err := p9.NewClient(conn, fs.opts.msize, fs.opts.version) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + conn.Close() + return err + } + // Ownership of conn has been transferred to client. + + fs.client = client + return nil +} + // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { - mf := fs.mfp.MemoryFile() + atomic.StoreInt32(&fs.released, 1) + mf := fs.mfp.MemoryFile() fs.syncMu.Lock() for d := range fs.syncableDentries { d.handleMu.Lock() d.dataMu.Lock() if h := d.writeHandleLocked(); h.isOpen() { // Write dirty cached data to the remote file. - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err) } // TODO(jamieliu): Do we need to flushf/fsync d? @@ -527,11 +548,16 @@ func (fs *filesystem) Release(ctx context.Context) { d.cache.DropAll(mf) d.dirty.RemoveAll() d.dataMu.Unlock() - // Close the host fd if one exists. - if d.hostFD >= 0 { - syscall.Close(int(d.hostFD)) - d.hostFD = -1 + // Close host FDs if they exist. + if d.readFD >= 0 { + syscall.Close(int(d.readFD)) } + if d.writeFD >= 0 && d.readFD != d.writeFD { + syscall.Close(int(d.writeFD)) + } + d.readFD = -1 + d.writeFD = -1 + d.mmapFD = -1 d.handleMu.Unlock() } // There can't be any specialFileFDs still using fs, since each such @@ -539,6 +565,21 @@ func (fs *filesystem) Release(ctx context.Context) { // fs. fs.syncMu.Unlock() + // If leak checking is enabled, release all outstanding references in the + // filesystem. We deliberately avoid doing this outside of leak checking; we + // have released all external resources above rather than relying on dentry + // destructors. + if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { + fs.renameMu.Lock() + fs.root.releaseSyntheticRecursiveLocked(ctx) + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // An extra reference was held by the filesystem on the root to prevent it from + // being cached/evicted. + fs.root.DecRef(ctx) + } + if !fs.iopts.LeakConnection { // Close the connection to the server. This implicitly clunks all fids. fs.client.Close() @@ -547,6 +588,31 @@ func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } +// releaseSyntheticRecursiveLocked traverses the tree with root d and decrements +// the reference count on every synthetic dentry. Synthetic dentries have one +// reference for existence that should be dropped during filesystem.Release. +// +// Precondition: d.fs.renameMu is locked. +func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { + if d.isSynthetic() { + d.decRefNoCaching() + d.checkCachingLocked(ctx) + } + if d.isDir() { + var children []*dentry + d.dirMu.Lock() + for _, child := range d.children { + children = append(children, child) + } + d.dirMu.Unlock() + for _, child := range children { + if child != nil { + child.releaseSyntheticRecursiveLocked(ctx) + } + } + } +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -574,12 +640,15 @@ type dentry struct { // filesystem.renameMu. name string + // qidPath is the p9.QID.Path for this file. qidPath is immutable. + qidPath uint64 + // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file // that does not exist on the remote filesystem. As of this writing, the // only files that can be synthetic are sockets, pipes, and directories. - file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + file p9file `state:"nosave"` // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. @@ -623,12 +692,12 @@ type dentry struct { // To mutate: // - Lock metadataMu and use atomic operations to update because we might // have atomic readers that don't hold the lock. - metadataMu sync.Mutex `state:"nosave"` - ino inodeNumber // immutable - mode uint32 // type is immutable, perms are mutable - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - blockSize uint32 // 0 if unknown + metadataMu sync.Mutex `state:"nosave"` + ino uint64 // immutable + mode uint32 // type is immutable, perms are mutable + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + blockSize uint32 // 0 if unknown // Timestamps, all nsecs from the Unix epoch. atime int64 mtime int64 @@ -662,26 +731,37 @@ type dentry struct { // - If this dentry represents a regular file or directory, readFile is the // p9.File used for reads by all regularFileFDs/directoryFDs representing - // this dentry. + // this dentry, and readFD (if not -1) is a host FD equivalent to readFile + // used as a faster alternative. // // - If this dentry represents a regular file, writeFile is the p9.File - // used for writes by all regularFileFDs representing this dentry. + // used for writes by all regularFileFDs representing this dentry, and + // writeFD (if not -1) is a host FD equivalent to writeFile used as a + // faster alternative. // - // - If this dentry represents a regular file, hostFD is the host FD used - // for memory mappings and I/O (when applicable) in preference to readFile - // and writeFile. hostFD is always readable; if !writeFile.isNil(), it must - // also be writable. If hostFD is -1, no such host FD is available. + // - If this dentry represents a regular file, mmapFD is the host FD used + // for memory mappings. If mmapFD is -1, no such FD is available, and the + // internal page cache implementation is used for memory mappings instead. // - // These fields are protected by handleMu. + // These fields are protected by handleMu. readFD, writeFD, and mmapFD are + // additionally written using atomic memory operations, allowing them to be + // read (albeit racily) with atomic.LoadInt32() without locking handleMu. // // readFile and writeFile may or may not represent the same p9.File. Once // either p9.File transitions from closed (isNil() == true) to open // (isNil() == false), it may be mutated with handleMu locked, but cannot // be closed until the dentry is destroyed. + // + // readFD and writeFD may or may not be the same file descriptor. mmapFD is + // always either -1 or equal to readFD; if !writeFile.isNil() (the file has + // been opened for writing), it is additionally either -1 or equal to + // writeFD. handleMu sync.RWMutex `state:"nosave"` - readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - hostFD int32 + readFile p9file `state:"nosave"` + writeFile p9file `state:"nosave"` + readFD int32 `state:"nosave"` + writeFD int32 `state:"nosave"` + mmapFD int32 `state:"nosave"` dataMu sync.RWMutex `state:"nosave"` @@ -758,13 +838,16 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d := &dentry{ fs: fs, + qidPath: qid.Path, file: file, - ino: inoFromPath(qid.Path), + ino: fs.inoFromQIDPath(qid.Path), mode: uint32(attr.Mode), uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), blockSize: usermem.PageSize, - hostFD: -1, + readFD: -1, + writeFD: -1, + mmapFD: -1, } d.pf.dentry = d if mask.UID { @@ -795,13 +878,28 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d.nlink = uint32(attr.NLink) } d.vfsd.Init(d) - + refsvfs2.Register(d) fs.syncMu.Lock() fs.syncableDentries[d] = struct{}{} fs.syncMu.Unlock() return d, nil } +func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 { + fs.inoMu.Lock() + defer fs.inoMu.Unlock() + if ino, ok := fs.inoByQIDPath[qidPath]; ok { + return ino + } + ino := fs.nextIno() + fs.inoByQIDPath[qidPath] = ino + return ino +} + +func (fs *filesystem) nextIno() uint64 { + return atomic.AddUint64(&fs.lastIno, 1) +} + func (d *dentry) isSynthetic() bool { return d.file.isNil() } @@ -853,7 +951,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } } -// Preconditions: !d.isSynthetic() +// Preconditions: !d.isSynthetic(). func (d *dentry) updateFromGetattr(ctx context.Context) error { // Use d.readFile or d.writeFile, which represent 9P fids that have been // opened, in preference to d.file, which represents a 9P fid that has not. @@ -916,10 +1014,10 @@ func (d *dentry) statTo(stat *linux.Statx) { // This is consistent with regularFileFD.Seek(), which treats regular files // as having no holes. stat.Blocks = (stat.Size + 511) / 512 - stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime)) - stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime)) - stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime)) - stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime)) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.atime)) + stat.Btime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.btime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.ctime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.mtime)) stat.DevMajor = linux.UNNAMED_MAJOR stat.DevMinor = d.fs.devMinor } @@ -967,10 +1065,10 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // Use client clocks for timestamps. now = d.fs.clock.Now().Nanoseconds() if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW { - stat.Atime = statxTimestampFromDentry(now) + stat.Atime = linux.NsecToStatxTimestamp(now) } if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW { - stat.Mtime = statxTimestampFromDentry(now) + stat.Mtime = linux.NsecToStatxTimestamp(now) } } @@ -1029,11 +1127,11 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // !d.cachedMetadataAuthoritative() then we returned after calling // d.file.setAttr(). For the same reason, now must have been initialized. if stat.Mask&linux.STATX_ATIME != 0 { - atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) + atomic.StoreInt64(&d.atime, stat.Atime.ToNsec()) atomic.StoreUint32(&d.atimeDirty, 0) } if stat.Mask&linux.STATX_MTIME != 0 { - atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) + atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec()) atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) @@ -1139,17 +1237,23 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 { func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against // d.checkCachingLocked(). - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + if d.LogRefs() { + refsvfs2.LogIncRef(d, r) + } } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + if d.LogRefs() { + refsvfs2.LogTryIncRef(d, r+1) + } return true } } @@ -1157,22 +1261,43 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + if d.decRefNoCaching() == 0 { d.fs.renameMu.Lock() d.checkCachingLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { - panic("gofer.dentry.DecRef() called without holding a reference") } } -// decRefLocked decrements d's reference count without calling +// decRefNoCaching decrements d's reference count without calling // d.checkCachingLocked, even if d's reference count reaches 0; callers are // responsible for ensuring that d.checkCachingLocked will be called later. -func (d *dentry) decRefLocked() { - if refs := atomic.AddInt64(&d.refs, -1); refs < 0 { - panic("gofer.dentry.decRefLocked() called without holding a reference") +func (d *dentry) decRefNoCaching() int64 { + r := atomic.AddInt64(&d.refs, -1) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } + if r < 0 { + panic("gofer.dentry.decRefNoCaching() called without holding a reference") } + return r +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *dentry) RefType() string { + return "gofer.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[gofer.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -1223,40 +1348,48 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { // resolution, which requires renameMu, so if d.refs is zero then it will // remain zero while we hold renameMu for writing.) refs := atomic.LoadInt64(&d.refs) - if refs > 0 { - if d.cached { - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentriesLen-- - d.cached = false - } - return - } if refs == -1 { // Dentry has already been destroyed. return } + if refs > 0 { + // This isn't strictly necessary (fs.cachedDentries is permitted to + // contain dentries with non-zero refs, which are skipped by + // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but + // since we are already holding fs.renameMu for writing we may as well. + d.removeFromCacheLocked() + return + } // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { if d.isDeleted() { d.watches.HandleDeletion(ctx) } - if d.cached { - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentriesLen-- - d.cached = false - } + d.removeFromCacheLocked() d.destroyLocked(ctx) return } - // If d still has inotify watches and it is not deleted or invalidated, we - // cannot cache it and allow it to be evicted. Otherwise, we will lose its - // watches, even if a new dentry is created for the same file in the future. - // Note that the size of d.watches cannot concurrently transition from zero - // to non-zero, because adding a watch requires holding a reference on d. + // If d still has inotify watches and it is not deleted or invalidated, it + // can't be evicted. Otherwise, we will lose its watches, even if a new + // dentry is created for the same file in the future. Note that the size of + // d.watches cannot concurrently transition from zero to non-zero, because + // adding a watch requires holding a reference on d. if d.watches.Size() > 0 { + // As in the refs > 0 case, this is not strictly necessary. + d.removeFromCacheLocked() return } + + if atomic.LoadInt32(&d.fs.released) != 0 { + if d.parent != nil { + d.parent.dirMu.Lock() + delete(d.parent.children, d.name) + d.parent.dirMu.Unlock() + } + d.destroyLocked(ctx) + } + // If d is already cached, just move it to the front of the LRU. if d.cached { d.fs.cachedDentries.Remove(d) @@ -1269,30 +1402,52 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { d.fs.cachedDentriesLen++ d.cached = true if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { - victim := d.fs.cachedDentries.Back() - d.fs.cachedDentries.Remove(victim) + d.fs.evictCachedDentryLocked(ctx) + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.opts.maxCachedDentries, so we don't loop. + } +} + +// Preconditions: d.fs.renameMu must be locked for writing. +func (d *dentry) removeFromCacheLocked() { + if d.cached { + d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- - victim.cached = false - // victim.refs may have become non-zero from an earlier path resolution - // since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 { - if victim.parent != nil { - victim.parent.dirMu.Lock() - if !victim.vfsd.IsDead() { - // Note that victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount points. - d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) - delete(victim.parent.children, victim.name) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victim.parent.dirMu.Unlock() + d.cached = false + } +} + +// Precondition: fs.renameMu must be locked for writing; it may be temporarily +// unlocked. +func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } +} + +// Preconditions: +// * fs.renameMu must be locked for writing; it may be temporarily unlocked. +// * fs.cachedDentriesLen != 0. +func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + victim := fs.cachedDentries.Back() + victim.removeFromCacheLocked() + // victim.refs or victim.watches.Size() may have become non-zero from an + // earlier path resolution since it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 { + if victim.parent != nil { + victim.parent.dirMu.Lock() + if !victim.vfsd.IsDead() { + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + // We're only deleting the dentry, not the file it + // represents, so we don't need to update + // victimParent.dirents etc. } - victim.destroyLocked(ctx) + victim.parent.dirMu.Unlock() } - // Whether or not victim was destroyed, we brought fs.cachedDentriesLen - // back down to fs.opts.maxCachedDentries, so we don't loop. + victim.destroyLocked(ctx) } } @@ -1343,10 +1498,15 @@ func (d *dentry) destroyLocked(ctx context.Context) { } d.readFile = p9file{} d.writeFile = p9file{} - if d.hostFD >= 0 { - syscall.Close(int(d.hostFD)) - d.hostFD = -1 + if d.readFD >= 0 { + syscall.Close(int(d.readFD)) } + if d.writeFD >= 0 && d.readFD != d.writeFD { + syscall.Close(int(d.writeFD)) + } + d.readFD = -1 + d.writeFD = -1 + d.mmapFD = -1 d.handleMu.Unlock() if !d.file.isNil() { @@ -1373,13 +1533,10 @@ func (d *dentry) destroyLocked(ctx context.Context) { // Drop the reference held by d on its parent without recursively locking // d.fs.renameMu. - if d.parent != nil { - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkCachingLocked(ctx) - } else if refs < 0 { - panic("gofer.dentry.DecRef() called without holding a reference") - } + if d.parent != nil && d.parent.decRefNoCaching() == 0 { + d.parent.checkCachingLocked(ctx) } + refsvfs2.Unregister(d) } func (d *dentry) isDeleted() bool { @@ -1461,7 +1618,8 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool d.handleMu.RUnlock() } - fdToClose := int32(-1) + var fdsToCloseArr [2]int32 + fdsToClose := fdsToCloseArr[:0] invalidateTranslations := false d.handleMu.Lock() if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc { @@ -1492,56 +1650,88 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool return err } - if d.hostFD < 0 && h.fd >= 0 && openReadable && (d.writeFile.isNil() || openWritable) { - // We have no existing FD, and the new FD meets the requirements - // for d.hostFD, so start using it. - d.hostFD = h.fd - } else if d.hostFD >= 0 && d.writeFile.isNil() && openWritable { - // We have an existing read-only FD, but the file has just been - // opened for writing, so we need to start supporting writable memory - // mappings. This may race with callers of d.pf.FD() using the existing - // FD, so in most cases we need to delay closing the old FD until after - // invalidating memmap.Translations that might have observed it. - if !openReadable || h.fd < 0 { - // We don't have a read/write FD, so we have no FD that can be - // used to create writable memory mappings. Switch to using the - // internal page cache. - invalidateTranslations = true - fdToClose = d.hostFD - d.hostFD = -1 - } else if d.fs.opts.overlayfsStaleRead { - // We do have a read/write FD, but it may not be coherent with - // the existing read-only FD, so we must switch to mappings of - // the new FD in both the application and sentry. - if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) - h.close(ctx) - return err + // Update d.readFD and d.writeFD. + if h.fd >= 0 { + if openReadable && openWritable && (d.readFD < 0 || d.writeFD < 0 || d.readFD != d.writeFD) { + // Replace existing FDs with this one. + if d.readFD >= 0 { + // We already have a readable FD that may be in use by + // concurrent callers of d.pf.FD(). + if d.fs.opts.overlayfsStaleRead { + // If overlayfsStaleRead is in effect, then the new FD + // may not be coherent with the existing one, so we + // have no choice but to switch to mappings of the new + // FD in both the application and sentry. + if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { + d.handleMu.Unlock() + ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) + h.close(ctx) + return err + } + fdsToClose = append(fdsToClose, d.readFD) + invalidateTranslations = true + atomic.StoreInt32(&d.readFD, h.fd) + } else { + // Otherwise, we want to avoid invalidating existing + // memmap.Translations (which is expensive); instead, use + // dup3 to make the old file descriptor refer to the new + // file description, then close the new file descriptor + // (which is no longer needed). Racing callers of d.pf.FD() + // may use the old or new file description, but this + // doesn't matter since they refer to the same file, and + // any racing mappings must be read-only. + if err := syscall.Dup3(int(h.fd), int(d.readFD), syscall.O_CLOEXEC); err != nil { + oldFD := d.readFD + d.handleMu.Unlock() + ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldFD, err) + h.close(ctx) + return err + } + fdsToClose = append(fdsToClose, h.fd) + h.fd = d.readFD + } + } else { + atomic.StoreInt32(&d.readFD, h.fd) } - invalidateTranslations = true - fdToClose = d.hostFD - d.hostFD = h.fd - } else { - // We do have a read/write FD. To avoid invalidating existing - // memmap.Translations (which is expensive), use dup3 to make - // the old file descriptor refer to the new file description, - // then close the new file descriptor (which is no longer - // needed). Racing callers of d.pf.FD() may use the old or new - // file description, but this doesn't matter since they refer - // to the same file, and any racing mappings must be read-only. - if err := syscall.Dup3(int(h.fd), int(d.hostFD), syscall.O_CLOEXEC); err != nil { - oldHostFD := d.hostFD - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldHostFD, err) - h.close(ctx) - return err + if d.writeFD != h.fd && d.writeFD >= 0 { + fdsToClose = append(fdsToClose, d.writeFD) + } + atomic.StoreInt32(&d.writeFD, h.fd) + atomic.StoreInt32(&d.mmapFD, h.fd) + } else if openReadable && d.readFD < 0 { + atomic.StoreInt32(&d.readFD, h.fd) + // If the file has not been opened for writing, the new FD may + // be used for read-only memory mappings. If the file was + // previously opened for reading (without an FD), then existing + // translations of the file may use the internal page cache; + // invalidate those mappings. + if d.writeFile.isNil() { + invalidateTranslations = !d.readFile.isNil() + atomic.StoreInt32(&d.mmapFD, h.fd) + } + } else if openWritable && d.writeFD < 0 { + atomic.StoreInt32(&d.writeFD, h.fd) + if d.readFD >= 0 { + // We have an existing read-only FD, but the file has just + // been opened for writing, so we need to start supporting + // writable memory mappings. However, the new FD is not + // readable, so we have no FD that can be used to create + // writable memory mappings. Switch to using the internal + // page cache. + invalidateTranslations = true + atomic.StoreInt32(&d.mmapFD, -1) } - fdToClose = h.fd + } else { + // The new FD is not useful. + fdsToClose = append(fdsToClose, h.fd) } - } else { - // h.fd is not useful. - fdToClose = h.fd + } else if openWritable && d.writeFD < 0 && d.mmapFD >= 0 { + // We have an existing read-only FD, but the file has just been + // opened for writing, so we need to start supporting writable + // memory mappings. However, we have no writable host FD. Switch to + // using the internal page cache. + invalidateTranslations = true + atomic.StoreInt32(&d.mmapFD, -1) } // Switch to new fids. @@ -1575,8 +1765,8 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool d.mappings.InvalidateAll(memmap.InvalidateOpts{}) d.mapsMu.Unlock() } - if fdToClose >= 0 { - syscall.Close(int(fdToClose)) + for _, fd := range fdsToClose { + syscall.Close(int(fd)) } return nil @@ -1586,7 +1776,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool func (d *dentry) readHandleLocked() handle { return handle{ file: d.readFile, - fd: d.hostFD, + fd: d.readFD, } } @@ -1594,7 +1784,7 @@ func (d *dentry) readHandleLocked() handle { func (d *dentry) writeHandleLocked() handle { return handle{ file: d.writeFile, - fd: d.hostFD, + fd: d.writeFD, } } @@ -1607,22 +1797,57 @@ func (d *dentry) syncRemoteFile(ctx context.Context) error { // Preconditions: d.handleMu must be locked. func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { // If we have a host FD, fsyncing it is likely to be faster than an fsync - // RPC. - if d.hostFD >= 0 { + // RPC. Prefer syncing write handles over read handles, since some remote + // filesystem implementations may not sync changes made through write + // handles otherwise. + if d.writeFD >= 0 { ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(d.hostFD)) + err := syscall.Fsync(int(d.writeFD)) ctx.UninterruptibleSleepFinish(false) return err } if !d.writeFile.isNil() { return d.writeFile.fsync(ctx) } + if d.readFD >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(d.readFD)) + ctx.UninterruptibleSleepFinish(false) + return err + } if !d.readFile.isNil() { return d.readFile.fsync(ctx) } return nil } +func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() + if h.isOpen() { + // Write back dirty pages to the remote file. + d.dataMu.Lock() + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) + d.dataMu.Unlock() + if err != nil { + return err + } + } + if err := d.syncRemoteFileLocked(ctx); err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (d is a regular file and was opened for writing). + if d.isRegularFile() && h.isOpen() { + return err + } + ctx.Debugf("gofer.dentry.syncCachedFile: syncing non-writable or non-regular-file dentry failed: %v", err) + } + return nil +} + // incLinks increments link count. func (d *dentry) incLinks() { if atomic.LoadUint32(&d.nlink) == 0 { @@ -1650,7 +1875,7 @@ type fileDescription struct { vfs.FileDescriptionDefaultImpl vfs.LockFD - lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + lockLogging sync.Once `state:"nosave"` } func (fd *fileDescription) filesystem() *filesystem { diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index bfe75dfe4..76f08e252 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -26,12 +26,13 @@ import ( func TestDestroyIdempotent(t *testing.T) { ctx := contexttest.Context(t) fs := filesystem{ - mfp: pgalloc.MemoryFileProviderFromContext(ctx), - syncableDentries: make(map[*dentry]struct{}), + mfp: pgalloc.MemoryFileProviderFromContext(ctx), opts: filesystemOptions{ // Test relies on no dentry being held in the cache. maxCachedDentries: 0, }, + syncableDentries: make(map[*dentry]struct{}), + inoByQIDPath: make(map[uint64]uint64), } attr := &p9.Attr{ diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go index 7294de7d6..c7bf10007 100644 --- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -51,8 +51,24 @@ func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error { if ok { return nil } - if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { - return err + if sleepErr := sleepBetweenNamedPipeOpenChecks(ctx); sleepErr != nil { + // Another application thread may have opened this pipe for + // writing, succeeded because we previously opened the pipe for + // reading, and subsequently interrupted us for checkpointing (e.g. + // this occurs in mknod tests under cooperative save/restore). In + // this case, our open has to succeed for the checkpoint to include + // a readable FD for the pipe, which is in turn necessary to + // restore the other thread's writable FD for the same pipe + // (otherwise it will get ENXIO). So we have to check + // nonblockingPipeHasWriter() once last time. + ok, err := nonblockingPipeHasWriter(fd) + if err != nil { + return err + } + if ok { + return nil + } + return sleepErr } } } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index f8b19bae7..283b220bb 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -18,7 +18,6 @@ import ( "fmt" "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -27,10 +26,12 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -48,6 +49,25 @@ type regularFileFD struct { off int64 } +func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD, error) { + fd := ®ularFileFD{} + fd.LockFD.Init(&d.locks) + if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ + AllowDirectIO: true, + }); err != nil { + return nil, err + } + if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { + fsmetric.GoferOpensWX.Increment() + } + if atomic.LoadInt32(&d.mmapFD) >= 0 { + fsmetric.GoferOpensHost.Increment() + } else { + fsmetric.GoferOpens9P.Increment() + } + return fd, nil +} + // Release implements vfs.FileDescriptionImpl.Release. func (fd *regularFileFD) Release(context.Context) { } @@ -89,6 +109,18 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + start := fsmetric.StartReadWait() + d := fd.dentry() + defer func() { + if atomic.LoadInt32(&d.readFD) >= 0 { + fsmetric.GoferReadsHost.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start) + } else { + fsmetric.GoferReads9P.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start) + } + }() + if offset < 0 { return 0, syserror.EINVAL } @@ -102,7 +134,6 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // Check for reading at EOF before calling into MM (but not under // InteropModeShared, which makes d.size unreliable). - d := fd.dentry() if d.cachedMetadataAuthoritative() && uint64(offset) >= atomic.LoadUint64(&d.size) { return 0, io.EOF } @@ -326,7 +357,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) // dentry.readHandleLocked() without locking dentry.dataMu. rw.d.handleMu.RLock() h := rw.d.readHandleLocked() - if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { + if (rw.d.mmapFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { n, err := h.readToBlocksAt(rw.ctx, dsts, rw.off) rw.d.handleMu.RUnlock() rw.off += n @@ -446,7 +477,7 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro // without locking dentry.dataMu. rw.d.handleMu.RLock() h := rw.d.writeHandleLocked() - if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { + if (rw.d.mmapFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { n, err := h.writeFromBlocksAt(rw.ctx, srcs, rw.off) rw.off += n rw.d.dataMu.Lock() @@ -624,23 +655,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncCachedFile(ctx) -} - -func (d *dentry) syncCachedFile(ctx context.Context) error { - d.handleMu.RLock() - defer d.handleMu.RUnlock() - - if h := d.writeHandleLocked(); h.isOpen() { - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) - d.dataMu.Unlock() - if err != nil { - return err - } - } - return d.syncRemoteFileLocked(ctx) + return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -663,10 +678,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt // Whether or not we have a host FD, we're not allowed to use it. return syserror.ENODEV } - d.handleMu.RLock() - haveFD := d.hostFD >= 0 - d.handleMu.RUnlock() - if !haveFD { + if atomic.LoadInt32(&d.mmapFD) < 0 { return syserror.ENODEV } default: @@ -684,10 +696,7 @@ func (d *dentry) mayCachePages() bool { if d.fs.opts.forcePageCache { return true } - d.handleMu.RLock() - haveFD := d.hostFD >= 0 - d.handleMu.RUnlock() - return haveFD + return atomic.LoadInt32(&d.mmapFD) >= 0 } // AddMapping implements memmap.Mappable.AddMapping. @@ -743,7 +752,7 @@ func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, // Translate implements memmap.Mappable.Translate. func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { d.handleMu.RLock() - if d.hostFD >= 0 && !d.fs.opts.forcePageCache { + if d.mmapFD >= 0 && !d.fs.opts.forcePageCache { d.handleMu.RUnlock() mr := optional if d.fs.opts.limitHostFDTranslation { @@ -897,7 +906,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { // cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef. // // dentryPlatformFile is only used when a host FD representing the remote file -// is available (i.e. dentry.hostFD >= 0), and that FD is used for application +// is available (i.e. dentry.mmapFD >= 0), and that FD is used for application // memory mappings (i.e. !filesystem.opts.forcePageCache). // // +stateify savable @@ -908,12 +917,12 @@ type dentryPlatformFile struct { // by dentry.dataMu. fdRefs fsutil.FrameRefSet - // If this dentry represents a regular file, and dentry.hostFD >= 0, - // hostFileMapper caches mappings of dentry.hostFD. + // If this dentry represents a regular file, and dentry.mmapFD >= 0, + // hostFileMapper caches mappings of dentry.mmapFD. hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. - hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + hostFileMapperInitOnce sync.Once `state:"nosave"` } // IncRef implements memmap.File.IncRef. @@ -934,12 +943,12 @@ func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() defer d.handleMu.RUnlock() - return d.hostFileMapper.MapInternal(fr, int(d.hostFD), at.Write) + return d.hostFileMapper.MapInternal(fr, int(d.mmapFD), at.Write) } // FD implements memmap.File.FD. func (d *dentryPlatformFile) FD() int { d.handleMu.RLock() defer d.handleMu.RUnlock() - return int(d.hostFD) + return int(d.mmapFD) } diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go new file mode 100644 index 000000000..c90071e4e --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -0,0 +1,331 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "fmt" + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type saveRestoreContextID int + +const ( + // CtxRestoreServerFDMap is a Context.Value key for a map[string]int + // mapping filesystem unique IDs (cf. InternalFilesystemOptions.UniqueID) + // to host FDs. + CtxRestoreServerFDMap saveRestoreContextID = iota +) + +// +stateify savable +type savedDentryRW struct { + read bool + write bool +} + +// PreprareSave implements vfs.FilesystemImplSaveRestoreExtension.PrepareSave. +func (fs *filesystem) PrepareSave(ctx context.Context) error { + if len(fs.iopts.UniqueID) == 0 { + return fmt.Errorf("gofer.filesystem with no UniqueID cannot be saved") + } + + // Purge cached dentries, which may not be reopenable after restore due to + // permission changes. + fs.renameMu.Lock() + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // Buffer pipe data so that it's available for reading after restore. (This + // is a legacy VFS1 feature.) + fs.syncMu.Lock() + for sffd := range fs.specialFileFDs { + if sffd.dentry().fileType() == linux.S_IFIFO && sffd.vfsfd.IsReadable() { + if err := sffd.savePipeData(ctx); err != nil { + fs.syncMu.Unlock() + return err + } + } + } + fs.syncMu.Unlock() + + // Flush local state to the remote filesystem. + if err := fs.Sync(ctx); err != nil { + return err + } + + fs.savedDentryRW = make(map[*dentry]savedDentryRW) + return fs.root.prepareSaveRecursive(ctx) +} + +// Preconditions: +// * fd represents a pipe. +// * fd is readable. +func (fd *specialFileFD) savePipeData(ctx context.Context) error { + fd.bufMu.Lock() + defer fd.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0)) + if n != 0 { + fd.buf = append(fd.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syserror.EAGAIN { + break + } + return err + } + } + if len(fd.buf) != 0 { + atomic.StoreUint32(&fd.haveBuf, 1) + } + return nil +} + +func (d *dentry) prepareSaveRecursive(ctx context.Context) error { + if d.isRegularFile() && !d.cachedMetadataAuthoritative() { + // Get updated metadata for d in case we need to perform metadata + // validation during restore. + if err := d.updateFromGetattr(ctx); err != nil { + return err + } + } + if !d.readFile.isNil() || !d.writeFile.isNil() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: !d.readFile.isNil(), + write: !d.writeFile.isNil(), + } + } + d.dirMu.Lock() + defer d.dirMu.Unlock() + for _, child := range d.children { + if child != nil { + if err := child.prepareSaveRecursive(ctx); err != nil { + return err + } + } + } + return nil +} + +// beforeSave is invoked by stateify. +func (d *dentry) beforeSave() { + if d.vfsd.IsDead() { + panic(fmt.Sprintf("gofer.dentry(%q).beforeSave: deleted and invalidated dentries can't be restored", genericDebugPathname(d))) + } +} + +// afterLoad is invoked by stateify. +func (d *dentry) afterLoad() { + d.readFD = -1 + d.writeFD = -1 + d.mmapFD = -1 + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} + +// afterLoad is invoked by stateify. +func (d *dentryPlatformFile) afterLoad() { + if d.hostFileMapper.IsInited() { + // Ensure that we don't call d.hostFileMapper.Init() again. + d.hostFileMapperInitOnce.Do(func() {}) + } +} + +// afterLoad is invoked by stateify. +func (fd *specialFileFD) afterLoad() { + fd.handle.fd = -1 +} + +// CompleteRestore implements +// vfs.FilesystemImplSaveRestoreExtension.CompleteRestore. +func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRestoreOptions) error { + fdmapv := ctx.Value(CtxRestoreServerFDMap) + if fdmapv == nil { + return fmt.Errorf("no server FD map available") + } + fdmap := fdmapv.(map[string]int) + fd, ok := fdmap[fs.iopts.UniqueID] + if !ok { + return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID) + } + fs.opts.fd = fd + if err := fs.dial(ctx); err != nil { + return err + } + fs.inoByQIDPath = make(map[uint64]uint64) + + // Restore the filesystem root. + ctx.UninterruptibleSleepStart(false) + attached, err := fs.client.Attach(fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return err + } + attachFile := p9file{attached} + qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) + if err != nil { + return err + } + if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { + return err + } + + // Restore remaining dentries. + if err := fs.root.restoreDescendantsRecursive(ctx, &opts); err != nil { + return err + } + + // Re-open handles for specialFileFDs. Unlike the initial open + // (dentry.openSpecialFile()), pipes are always opened without blocking; + // non-readable pipe FDs are opened last to ensure that they don't get + // ENXIO if another specialFileFD represents the read end of the same pipe. + // This is consistent with VFS1. + haveWriteOnlyPipes := false + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + haveWriteOnlyPipes = true + continue + } + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + if haveWriteOnlyPipes { + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + } + } + + // Discard state only required during restore. + fs.savedDentryRW = nil + + return nil +} + +func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrMask p9.AttrMask, attr *p9.Attr, opts *vfs.CompleteRestoreOptions) error { + d.file = file + + // Gofers do not preserve QID across checkpoint/restore, so: + // + // - We must assume that the remote filesystem did not change in a way that + // would invalidate dentries, since we can't revalidate dentries by + // checking QIDs. + // + // - We need to associate the new QID.Path with the existing d.ino. + d.qidPath = qid.Path + d.fs.inoMu.Lock() + d.fs.inoByQIDPath[qid.Path] = d.ino + d.fs.inoMu.Unlock() + + // Check metadata stability before updating metadata. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + if d.isRegularFile() { + if opts.ValidateFileSizes { + if !attrMask.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d)) + } + if d.size != attr.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size) + } + } + if opts.ValidateFileModificationTimestamps { + if !attrMask.MTime { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d)) + } + if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want)) + } + } + } + if !d.cachedMetadataAuthoritative() { + d.updateFromP9AttrsLocked(attrMask, attr) + } + + if rw, ok := d.fs.savedDentryRW[d]; ok { + if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil { + return err + } + } + + return nil +} + +// Preconditions: d is not synthetic. +func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + for _, child := range d.children { + if child == nil { + continue + } + if _, ok := d.fs.syncableDentries[child]; !ok { + // child is synthetic. + continue + } + if err := child.restoreRecursive(ctx, opts); err != nil { + return err + } + } + return nil +} + +// Preconditions: d is not synthetic (but note that since this function +// restores d.file, d.file.isNil() is always true at this point, so this can +// only be detected by checking filesystem.syncableDentries). d.parent has been +// restored. +func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { + return err + } + return d.restoreDescendantsRecursive(ctx, opts) +} + +func (fd *specialFileFD) completeRestore(ctx context.Context) error { + d := fd.dentry() + h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + if err != nil { + return err + } + fd.handle = h + + ftype := d.fileType() + fd.haveQueue = (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && fd.handle.fd >= 0 + if fd.haveQueue { + if err := fdnotifier.AddFD(fd.handle.fd, &fd.queue); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index 326b940a7..a21199eac 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -42,9 +42,6 @@ type endpoint struct { // dentry is the filesystem dentry which produced this endpoint. dentry *dentry - // file is the p9 file that contains a single unopened fid. - file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - // path is the sentry path where this endpoint is bound. path string } @@ -116,7 +113,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect } func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { - hostFile, err := e.file.Connect(flags) + hostFile, err := e.dentry.file.connect(ctx, flags) if err != nil { return nil, syserr.ErrConnectionRefused } @@ -131,7 +128,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path) if serr != nil { - log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr) + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr) return nil, serr } return c, nil diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 71581736c..089955a96 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -15,7 +15,6 @@ package gofer import ( - "sync" "sync/atomic" "syscall" @@ -24,7 +23,9 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -40,7 +41,7 @@ type specialFileFD struct { fileDescription // handle is used for file I/O. handle is immutable. - handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + handle handle `state:"nosave"` // isRegularFile is true if this FD represents a regular file which is only // possible when filesystemOptions.regularFilesUseSpecialFileFD is in @@ -54,15 +55,23 @@ type specialFileFD struct { // haveQueue is true if this file description represents a file for which // queue may send I/O readiness events. haveQueue is immutable. - haveQueue bool + haveQueue bool `state:"nosave"` queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex `state:"nosave"` off int64 + + // If haveBuf is non-zero, this FD represents a pipe, and buf contains data + // read from the pipe from previous calls to specialFileFD.savePipeData(). + // haveBuf and buf are protected by bufMu. haveBuf is accessed using atomic + // memory operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte } -func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { +func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) { ftype := d.fileType() seekable := ftype == linux.S_IFREG || ftype == linux.S_IFCHR || ftype == linux.S_IFBLK haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0 @@ -72,7 +81,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, seekable: seekable, haveQueue: haveQueue, } - fd.LockFD.Init(locks) + fd.LockFD.Init(&d.locks) if haveQueue { if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil { return nil, err @@ -87,6 +96,17 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, } return nil, err } + d.fs.syncMu.Lock() + d.fs.specialFileFDs[fd] = struct{}{} + d.fs.syncMu.Unlock() + if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { + fsmetric.GoferOpensWX.Increment() + } + if h.fd >= 0 { + fsmetric.GoferOpensHost.Increment() + } else { + fsmetric.GoferOpens9P.Increment() + } return fd, nil } @@ -150,6 +170,17 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + start := fsmetric.StartReadWait() + defer func() { + if fd.handle.fd >= 0 { + fsmetric.GoferReadsHost.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start) + } else { + fsmetric.GoferReads9P.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start) + } + }() + if fd.seekable && offset < 0 { return 0, syserror.EINVAL } @@ -161,26 +192,51 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs return 0, syserror.EOPNOTSUPP } - // Going through dst.CopyOutFrom() holds MM locks around file operations of - // unknown duration. For regularFileFD, doing so is necessary to support - // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't - // hold here since specialFileFD doesn't client-cache data. Just buffer the - // read instead. if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } + + bufN := int64(0) + if atomic.LoadUint32(&fd.haveBuf) != 0 { + var err error + fd.bufMu.Lock() + if len(fd.buf) != 0 { + var n int + n, err = dst.CopyOut(ctx, fd.buf) + dst = dst.DropFirst(n) + fd.buf = fd.buf[n:] + if len(fd.buf) == 0 { + atomic.StoreUint32(&fd.haveBuf, 0) + fd.buf = nil + } + bufN = int64(n) + if offset >= 0 { + offset += bufN + } + } + fd.bufMu.Unlock() + if err != nil { + return bufN, err + } + } + + // Going through dst.CopyOutFrom() would hold MM locks around file + // operations of unknown duration. For regularFileFD, doing so is necessary + // to support mmap due to lock ordering; MM locks precede dentry.dataMu. + // That doesn't hold here since specialFileFD doesn't client-cache data. + // Just buffer the read instead. buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } if n == 0 { - return 0, err + return bufN, err } if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil { - return int64(cp), cperr + return bufN + int64(cp), cperr } - return int64(n), err + return bufN + int64(n), err } // Read implements vfs.FileDescriptionImpl.Read. @@ -217,16 +273,16 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } d := fd.dentry() - // If the regular file fd was opened with O_APPEND, make sure the file size - // is updated. There is a possible race here if size is modified externally - // after metadata cache is updated. - if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { - if err := d.updateFromGetattr(ctx); err != nil { - return 0, offset, err + if fd.isRegularFile { + // If the regular file fd was opened with O_APPEND, make sure the file + // size is updated. There is a possible race here if size is modified + // externally after metadata cache is updated. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } - } - if fd.isRegularFile { // We need to hold the metadataMu *while* writing to a regular file. d.metadataMu.Lock() defer d.metadataMu.Unlock() @@ -306,13 +362,31 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - // If we have a host FD, fsyncing it is likely to be faster than an fsync - // RPC. - if fd.handle.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(fd.handle.fd)) - ctx.UninterruptibleSleepFinish(false) - return err - } - return fd.handle.file.fsync(ctx) + return fd.sync(ctx, false /* forFilesystemSync */) +} + +func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error { + err := func() error { + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if fd.handle.fd >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(fd.handle.fd)) + ctx.UninterruptibleSleepFinish(false) + return err + } + return fd.handle.file.fsync(ctx) + }() + if err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (fd represents a regular file that was opened for writing). + if fd.isRegularFile && fd.vfsfd.IsWritable() { + return err + } + ctx.Debugf("gofer.specialFileFD.sync: syncing non-writable or non-regular-file FD failed: %v", err) + } + return nil } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 7e825caae..9cbe805b9 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -17,7 +17,6 @@ package gofer import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -25,17 +24,6 @@ func dentryTimestampFromP9(s, ns uint64) int64 { return int64(s*1e9 + ns) } -func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 { - return ts.Sec*1e9 + int64(ts.Nsec) -} - -func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { - return linux.StatxTimestamp{ - Sec: ns / 1e9, - Nsec: uint32(ns % 1e9), - } -} - // Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime || mnt.ReadOnly() { diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 56bcf9bdb..4ae9d6d5e 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "inode_refs.go", package = "host", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "connected_endpoint_refs.go", package = "host", prefix = "ConnectedEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ConnectedEndpoint", }, @@ -33,7 +33,7 @@ go_library( "host.go", "inode_refs.go", "ioctl_unsafe.go", - "mmap.go", + "save_restore.go", "socket.go", "socket_iovec.go", "socket_unsafe.go", @@ -51,6 +51,7 @@ go_library( "//pkg/log", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go index 0135e4428..13ef48cb5 100644 --- a/pkg/sentry/fsimpl/host/control.go +++ b/pkg/sentry/fsimpl/host/control.go @@ -79,7 +79,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription { } // Create the file backed by hostFD. - file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */) + file, err := NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, &NewFDOptions{}) if err != nil { ctx.Warningf("Error creating file from host FD: %v", err) break diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 698e913fe..435a21d77 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -19,6 +19,7 @@ package host import ( "fmt" "math" + "sync/atomic" "syscall" "golang.org/x/sys/unix" @@ -40,34 +41,97 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) { - // Determine if hostFD is seekable. If not, this syscall will return ESPIPE - // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character - // devices. +// inode implements kernfs.Inode. +// +// +stateify savable +type inode struct { + kernfs.InodeNoStatFS + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + kernfs.CachedMappable + kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. + + locks vfs.FileLocks + + // When the reference count reaches zero, the host fd is closed. + inodeRefs + + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + hostFD int + + // ino is an inode number unique within this filesystem. + // + // This field is initialized at creation time and is immutable. + ino uint64 + + // ftype is the file's type (a linux.S_IFMT mask). + // + // This field is initialized at creation time and is immutable. + ftype uint16 + + // mayBlock is true if hostFD is non-blocking, and operations on it may + // return EAGAIN or EWOULDBLOCK instead of blocking. + // + // This field is initialized at creation time and is immutable. + mayBlock bool + + // seekable is false if lseek(hostFD) returns ESPIPE. We assume that file + // offsets are meaningful iff seekable is true. + // + // This field is initialized at creation time and is immutable. + seekable bool + + // isTTY is true if this file represents a TTY. + // + // This field is initialized at creation time and is immutable. + isTTY bool + + // savable is true if hostFD may be saved/restored by its numeric value. + // + // This field is initialized at creation time and is immutable. + savable bool + + // Event queue for blocking operations. + queue waiter.Queue + + // If haveBuf is non-zero, hostFD represents a pipe, and buf contains data + // read from the pipe from previous calls to inode.beforeSave(). haveBuf + // and buf are protected by bufMu. haveBuf is accessed using atomic memory + // operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte +} + +func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) { + // Determine if hostFD is seekable. _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) seekable := err != syserror.ESPIPE + // We expect regular files to be seekable, as this is required for them to + // be memory-mappable. + if !seekable && fileType == syscall.S_IFREG { + ctx.Infof("host.newInode: host FD %d is a non-seekable regular file", hostFD) + return nil, syserror.ESPIPE + } i := &inode{ - hostFD: hostFD, - ino: fs.NextIno(), - isTTY: isTTY, - wouldBlock: wouldBlock(uint32(fileType)), - seekable: seekable, - // NOTE(b/38213152): Technically, some obscure char devices can be memory - // mapped, but we only allow regular files. - canMap: fileType == linux.S_IFREG, - } - i.pf.inode = i - i.EnableLeakCheck() - - // Non-seekable files can't be memory mapped, assert this. - if !i.seekable && i.canMap { - panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") - } - - // If the hostFD would block, we must set it to non-blocking and handle - // blocking behavior in the sentry. - if i.wouldBlock { + hostFD: hostFD, + ino: fs.NextIno(), + ftype: uint16(fileType), + mayBlock: fileType != syscall.S_IFREG && fileType != syscall.S_IFDIR, + seekable: seekable, + isTTY: isTTY, + savable: savable, + } + i.InitRefs() + i.CachedMappable.Init(hostFD) + + // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and + // handle blocking behavior in the sentry. + if i.mayBlock { if err := syscall.SetNonblock(i.hostFD, true); err != nil { return nil, err } @@ -80,6 +144,11 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) ( // NewFDOptions contains options to NewFD. type NewFDOptions struct { + // If Savable is true, the host file descriptor may be saved/restored by + // numeric value; the sandbox API requires a corresponding host FD with the + // same numeric value to be provieded at time of restore. + Savable bool + // If IsTTY is true, the file descriptor is a TTY. IsTTY bool @@ -114,7 +183,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) } d := &kernfs.Dentry{} - i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY) + i, err := newInode(ctx, fs, hostFD, opts.Savable, linux.FileMode(s.Mode).FileType(), opts.IsTTY) if err != nil { return nil, err } @@ -132,7 +201,8 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) // ImportFD sets up and returns a vfs.FileDescription from a donated fd. func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) { return NewFD(ctx, mnt, hostFD, &NewFDOptions{ - IsTTY: isTTY, + Savable: true, + IsTTY: isTTY, }) } @@ -191,68 +261,6 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } -// inode implements kernfs.Inode. -// -// +stateify savable -type inode struct { - kernfs.InodeNoStatFS - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink - kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. - - locks vfs.FileLocks - - // When the reference count reaches zero, the host fd is closed. - inodeRefs - - // hostFD contains the host fd that this file was originally created from, - // which must be available at time of restore. - // - // This field is initialized at creation time and is immutable. - hostFD int - - // ino is an inode number unique within this filesystem. - // - // This field is initialized at creation time and is immutable. - ino uint64 - - // isTTY is true if this file represents a TTY. - // - // This field is initialized at creation time and is immutable. - isTTY bool - - // seekable is false if the host fd points to a file representing a stream, - // e.g. a socket or a pipe. Such files are not seekable and can return - // EWOULDBLOCK for I/O operations. - // - // This field is initialized at creation time and is immutable. - seekable bool - - // wouldBlock is true if the host FD would return EWOULDBLOCK for - // operations that would block. - // - // This field is initialized at creation time and is immutable. - wouldBlock bool - - // Event queue for blocking operations. - queue waiter.Queue - - // canMap specifies whether we allow the file to be memory mapped. - // - // This field is initialized at creation time and is immutable. - canMap bool - - // mapsMu protects mappings. - mapsMu sync.Mutex `state:"nosave"` - - // If canMap is true, mappings tracks mappings of hostFD into - // memmap.MappingSpaces. - mappings memmap.MappingSet - - // pf implements platform.File for mappings of hostFD. - pf inodePlatformFile -} - // CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s syscall.Stat_t @@ -422,14 +430,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre oldpgend, _ := usermem.PageRoundUp(oldSize) newpgend, _ := usermem.PageRoundUp(s.Size) if oldpgend != newpgend { - i.mapsMu.Lock() - i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ - // Compare Linux's mm/truncate.c:truncate_setsize() => - // truncate_pagecache() => - // mm/memory.c:unmap_mapping_range(evencows=1). - InvalidatePrivate: true, - }) - i.mapsMu.Unlock() + i.CachedMappable.InvalidateRange(memmap.MappableRange{newpgend, oldpgend}) } } } @@ -448,7 +449,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre // DecRef implements kernfs.Inode.DecRef. func (i *inode) DecRef(ctx context.Context) { i.inodeRefs.DecRef(func() { - if i.wouldBlock { + if i.mayBlock { fdnotifier.RemoveFD(int32(i.hostFD)) } if err := unix.Close(i.hostFD); err != nil { @@ -567,6 +568,13 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin // PRead implements vfs.FileDescriptionImpl.PRead. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { return 0, syserror.ESPIPE @@ -577,19 +585,31 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off // Read implements vfs.FileDescriptionImpl.Read. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { + bufN, err := i.readFromBuf(ctx, &dst) + if err != nil { + return bufN, err + } n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags) + total := bufN + n if isBlockError(err) { // If we got any data at all, return it as a "completed" partial read // rather than retrying until complete. - if n != 0 { + if total != 0 { err = nil } else { err = syserror.ErrWouldBlock } } - return n, err + return total, err } f.offsetMu.Lock() @@ -599,13 +619,26 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts return n, err } -func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // Check that flags are supported. - // - // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. - if flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP +func (i *inode) readFromBuf(ctx context.Context, dst *usermem.IOSequence) (int64, error) { + if atomic.LoadUint32(&i.haveBuf) == 0 { + return 0, nil + } + i.bufMu.Lock() + defer i.bufMu.Unlock() + if len(i.buf) == 0 { + return 0, nil } + n, err := dst.CopyOut(ctx, i.buf) + *dst = dst.DropFirst(n) + i.buf = i.buf[n:] + if len(i.buf) == 0 { + atomic.StoreUint32(&i.haveBuf, 0) + i.buf = nil + } + return int64(n), err +} + +func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) n, err := dst.CopyOutFrom(ctx, reader) hostfd.PutReadWriterAt(reader) @@ -735,31 +768,37 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i } // Sync implements vfs.FileDescriptionImpl.Sync. -func (f *fileDescription) Sync(context.Context) error { +func (f *fileDescription) Sync(ctx context.Context) error { // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { - if !f.inode.canMap { + // NOTE(b/38213152): Technically, some obscure char devices can be memory + // mapped, but we only allow regular files. + if f.inode.ftype != syscall.S_IFREG { return syserror.ENODEV } i := f.inode - i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init) + i.CachedMappable.InitFileMapperOnce() return vfs.GenericConfigureMMap(&f.vfsfd, i, opts) } // EventRegister implements waiter.Waitable.EventRegister. func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { f.inode.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // EventUnregister implements waiter.Waitable.EventUnregister. func (f *fileDescription) EventUnregister(e *waiter.Entry) { f.inode.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // Readiness uses the poll() syscall to check the status of the underlying FD. diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go new file mode 100644 index 000000000..8800652a9 --- /dev/null +++ b/pkg/sentry/fsimpl/host/save_restore.go @@ -0,0 +1,70 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "fmt" + "io" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/usermem" +) + +// beforeSave is invoked by stateify. +func (i *inode) beforeSave() { + if !i.savable { + panic("host.inode is not savable") + } + if i.ftype == syscall.S_IFIFO { + // If this pipe FD is readable, drain it so that bytes in the pipe can + // be read after restore. (This is a legacy VFS1 feature.) We don't + // know if the pipe FD is readable, so just try reading and tolerate + // EBADF from the read. + i.bufMu.Lock() + defer i.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */) + if n != 0 { + i.buf = append(i.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syscall.EAGAIN || err == syscall.EBADF { + break + } + panic(fmt.Errorf("host.inode.beforeSave: buffering from pipe failed: %v", err)) + } + } + if len(i.buf) != 0 { + atomic.StoreUint32(&i.haveBuf, 1) + } + } +} + +// afterLoad is invoked by stateify. +func (i *inode) afterLoad() { + if i.mayBlock { + if err := syscall.SetNonblock(i.hostFD, true); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: failed to set host FD %d non-blocking: %v", i.hostFD, err)) + } + if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: fdnotifier.AddFD(%d) failed: %v", i.hostFD, err)) + } + } +} diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 8a447e29f..60acc367f 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -84,6 +84,8 @@ type ConnectedEndpoint struct { // init performs initialization required for creating new ConnectedEndpoints and // for restoring them. func (c *ConnectedEndpoint) init() *syserr.Error { + c.InitRefs() + family, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN) if err != nil { return syserr.FromError(err) @@ -132,7 +134,6 @@ func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable // ConnectedEndpointRefs start off with a single reference. We need two. e.IncRef() - e.EnableLeakCheck() return &e, nil } @@ -376,8 +377,7 @@ func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr s return nil, err } - // ConnectedEndpointRefs start off with a single reference. We need two. + // e starts off with a single reference. We need two. e.IncRef() - e.EnableLeakCheck() return &e, nil } diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 412bdb2eb..b2f43a119 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -43,12 +43,6 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)} } -// wouldBlock returns true for file types that can return EWOULDBLOCK -// for blocking operations, e.g. pipes, character devices, and sockets. -func wouldBlock(fileType uint32) bool { - return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK -} - // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. // If so, they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 858cc24ce..6dbc7e34d 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -4,6 +4,18 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) go_template_instance( + name = "dentry_list", + out = "dentry_list.go", + package = "kernfs", + prefix = "dentry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Dentry", + "Linker": "*Dentry", + }, +) + +go_template_instance( name = "fstree", out = "fstree.go", package = "kernfs", @@ -27,22 +39,11 @@ go_template_instance( ) go_template_instance( - name = "dentry_refs", - out = "dentry_refs.go", - package = "kernfs", - prefix = "Dentry", - template = "//pkg/refs_vfs2:refs_template", - types = { - "T": "Dentry", - }, -) - -go_template_instance( name = "static_directory_refs", out = "static_directory_refs.go", package = "kernfs", prefix = "StaticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "StaticDirectory", }, @@ -53,7 +54,7 @@ go_template_instance( out = "dir_refs.go", package = "kernfs_test", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -64,7 +65,7 @@ go_template_instance( out = "readonly_dir_refs.go", package = "kernfs_test", prefix = "readonlyDir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "readonlyDir", }, @@ -75,7 +76,7 @@ go_template_instance( out = "synthetic_directory_refs.go", package = "kernfs", prefix = "syntheticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "syntheticDirectory", }, @@ -84,13 +85,15 @@ go_template_instance( go_library( name = "kernfs", srcs = [ - "dentry_refs.go", + "dentry_list.go", "dynamic_bytes_file.go", "fd_impl_util.go", "filesystem.go", "fstree.go", "inode_impl_util.go", "kernfs.go", + "mmap_util.go", + "save_restore.go", "slot_list.go", "static_directory_refs.go", "symlink.go", @@ -104,8 +107,12 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", + "//pkg/safemem", + "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", @@ -129,6 +136,7 @@ go_test( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index b929118b1..485504995 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -47,11 +47,11 @@ type DynamicBytesFile struct { var _ Inode = (*DynamicBytesFile)(nil) // Init initializes a dynamic bytes file. -func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { +func (f *DynamicBytesFile) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) f.data = data } diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index abf1905d6..f8dae22f8 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -145,8 +145,12 @@ func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { return fd.vfsfd.VirtualDentry().Mount().Filesystem() } +func (fd *GenericDirectoryFD) dentry() *Dentry { + return fd.vfsfd.Dentry().Impl().(*Dentry) +} + func (fd *GenericDirectoryFD) inode() Inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode + return fd.dentry().inode } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds @@ -176,8 +180,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent // Handle "..". if fd.off == 1 { - vfsd := fd.vfsfd.VirtualDentry().Dentry() - parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode + parentInode := genericParentOrSelf(fd.dentry()).inode stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err @@ -219,7 +222,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent var err error relOffset := fd.off - int64(len(fd.children.set)) - 2 - fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset) + fd.off, err = fd.inode().IterDirents(ctx, fd.vfsfd.Mount(), cb, fd.off, relOffset) return err } @@ -265,8 +268,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) - inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode - return inode.SetStat(ctx, fd.filesystem(), creds, opts) + return fd.inode().SetStat(ctx, fd.filesystem(), creds, opts) } // Allocate implements vfs.FileDescriptionImpl.Allocate. diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 6426a55f6..e77523f22 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -207,24 +207,23 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // Preconditions: // * Filesystem.mu must be locked for at least reading. // * isDir(parentInode) == true. -func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) { - if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { - return "", err +func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error { + if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err } - pc := rp.Component() - if pc == "." || pc == ".." { - return "", syserror.EEXIST + if name == "." || name == ".." { + return syserror.EEXIST } - if len(pc) > linux.NAME_MAX { - return "", syserror.ENAMETOOLONG + if len(name) > linux.NAME_MAX { + return syserror.ENAMETOOLONG } - if _, ok := parent.children[pc]; ok { - return "", syserror.EEXIST + if _, ok := parent.children[name]; ok { + return syserror.EEXIST } if parent.VFSDentry().IsDead() { - return "", syserror.ENOENT + return syserror.ENOENT } - return pc, nil + return nil } // checkDeleteLocked checks that the file represented by vfsd may be deleted. @@ -245,7 +244,41 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) er } // Release implements vfs.FilesystemImpl.Release. -func (fs *Filesystem) Release(context.Context) { +func (fs *Filesystem) Release(ctx context.Context) { + root := fs.root + if root == nil { + return + } + fs.mu.Lock() + root.releaseKeptDentriesLocked(ctx) + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } + fs.mu.Unlock() + // Drop ref acquired in Dentry.InitRoot(). + root.DecRef(ctx) +} + +// releaseKeptDentriesLocked recursively drops all dentry references created by +// Lookup when Dentry.inode.Keep() is true. +// +// Precondition: Filesystem.mu is held. +func (d *Dentry) releaseKeptDentriesLocked(ctx context.Context) { + if d.inode.Keep() && d != d.fs.root { + d.decRefLocked(ctx) + } + + if d.isDir() { + var children []*Dentry + d.dirMu.Lock() + for _, child := range d.children { + children = append(children, child) + } + d.dirMu.Unlock() + for _, child := range children { + child.releaseKeptDentriesLocked(ctx) + } + } } // Sync implements vfs.FilesystemImpl.Sync. @@ -318,10 +351,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if rp.Mount() != vd.Mount() { return syserror.EXDEV } @@ -360,8 +396,8 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } if err := rp.Mount().CheckBeginWrite(); err != nil { @@ -373,7 +409,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err } - childI = newSyntheticDirectory(rp.Credentials(), opts.Mode) + childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode) } var child Dentry child.Init(fs, childI) @@ -396,10 +432,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if err := rp.Mount().CheckBeginWrite(); err != nil { return err } @@ -517,9 +556,6 @@ afterTrailingSymlink: } var child Dentry child.Init(fs, childI) - // FIXME(gvisor.dev/issue/1193): Race between checking existence with - // fs.stepExistingLocked and parent.insertChild. If possible, we should hold - // dirMu from one to the other. parent.insertChild(pc, &child) // Open may block so we need to unlock fs.mu. IncRef child to prevent // its destruction while fs.mu is unlocked. @@ -626,8 +662,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Can we create the dst dentry? var dst *Dentry - pc, err := checkCreateLocked(ctx, rp, dstDir) - switch err { + pc := rp.Component() + switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err { case nil: // Ok, continue with rename as replacement. case syserror.EEXIST: @@ -791,10 +827,13 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if err := rp.Mount().CheckBeginWrite(); err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 122b10591..eac578f25 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -21,9 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // InodeNoopRefCount partially implements the Inode interface, specifically the @@ -143,7 +145,7 @@ func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error) } // IterDirents implements Inode.IterDirents. -func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (InodeNotDirectory) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { panic("IterDirents called on non-directory inode") } @@ -172,17 +174,23 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, // // +stateify savable type InodeAttrs struct { - devMajor uint32 - devMinor uint32 - ino uint64 - mode uint32 - uid uint32 - gid uint32 - nlink uint32 + devMajor uint32 + devMinor uint32 + ino uint64 + mode uint32 + uid uint32 + gid uint32 + nlink uint32 + blockSize uint32 + + // Timestamps, all nsecs from the Unix epoch. + atime int64 + mtime int64 + ctime int64 } // Init initializes this InodeAttrs. -func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { +func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { if mode.FileType() == 0 { panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) } @@ -198,6 +206,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, in atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) atomic.StoreUint32(&a.nlink, nlink) + atomic.StoreUint32(&a.blockSize, usermem.PageSize) + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.atime, now) + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) } // DevMajor returns the device major number. @@ -220,12 +233,33 @@ func (a *InodeAttrs) Mode() linux.FileMode { return linux.FileMode(atomic.LoadUint32(&a.mode)) } +// TouchAtime updates a.atime to the current time. +func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) { + if mnt.Flags.NoATime || mnt.ReadOnly() { + return + } + if err := mnt.CheckBeginWrite(); err != nil { + return + } + atomic.StoreInt64(&a.atime, ktime.NowFromContext(ctx).Nanoseconds()) + mnt.EndWrite() +} + +// TouchCMtime updates a.{c/m}time to the current time. The caller should +// synchronize calls to this so that ctime and mtime are updated to the same +// value. +func (a *InodeAttrs) TouchCMtime(ctx context.Context) { + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) +} + // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK + stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME stat.DevMajor = a.devMajor stat.DevMinor = a.devMinor stat.Ino = atomic.LoadUint64(&a.ino) @@ -233,21 +267,15 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li stat.UID = atomic.LoadUint32(&a.uid) stat.GID = atomic.LoadUint32(&a.gid) stat.Nlink = atomic.LoadUint32(&a.nlink) - - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - + stat.Blksize = atomic.LoadUint32(&a.blockSize) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.atime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.mtime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.ctime)) return stat, nil } // SetStat implements Inode.SetStat. func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - return a.SetInodeStat(ctx, fs, creds, opts) -} - -// SetInodeStat sets the corresponding attributes from opts to InodeAttrs. -// This function can be used by other kernfs-based filesystem implementation to -// sets the unexported attributes into InodeAttrs. -func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask == 0 { return nil } @@ -256,9 +284,7 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds // inode numbers are immutable after node creation. Setting the size is often // allowed by kernfs files but does not do anything. If some other behavior is // needed, the embedder should consider extending SetStat. - // - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 { return syserror.EPERM } if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { @@ -286,6 +312,20 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds atomic.StoreUint32(&a.gid, stat.GID) } + now := ktime.NowFromContext(ctx).Nanoseconds() + if stat.Mask&linux.STATX_ATIME != 0 { + if stat.Atime.Nsec == linux.UTIME_NOW { + stat.Atime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.atime, stat.Atime.ToNsec()) + } + if stat.Mask&linux.STATX_MTIME != 0 { + if stat.Mtime.Nsec == linux.UTIME_NOW { + stat.Mtime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.mtime, stat.Mtime.ToNsec()) + } + return nil } @@ -421,7 +461,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error } // IterDirents implements Inode.IterDirents. -func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (o *OrderedChildren) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { // All entries from OrderedChildren have already been handled in // GenericDirectoryFD.IterDirents. return offset, nil @@ -528,13 +568,6 @@ func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) e return o.Unlink(ctx, name, child) } -// +stateify savable -type renameAcrossDifferentImplementationsError struct{} - -func (renameAcrossDifferentImplementationsError) Error() string { - return "rename across inodes with different implementations" -} - // Rename implements Inode.Rename. // // Precondition: Rename may only be called across two directory inodes with @@ -545,13 +578,18 @@ func (renameAcrossDifferentImplementationsError) Error() string { // // Postcondition: reference on any replaced dentry transferred to caller. func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error { + if !o.writable { + return syserror.EPERM + } + dst, ok := dstDir.(interface{}).(*OrderedChildren) if !ok { - return renameAcrossDifferentImplementationsError{} + return syserror.EXDEV } - if !o.writable || !dst.writable { + if !dst.writable { return syserror.EPERM } + // Note: There's a potential deadlock below if concurrent calls to Rename // refer to the same src and dst directories in reverse. We avoid any // ordering issues because the caller is required to serialize concurrent @@ -619,10 +657,10 @@ type StaticDirectory struct { var _ Inode = (*StaticDirectory)(nil) // NewStaticDir creates a new static directory and returns its dentry. -func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { +func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { inode := &StaticDirectory{} - inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts) - inode.EnableLeakCheck() + inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts) + inode.InitRefs() inode.OrderedChildren.Init(OrderedChildrenOptions{}) links := inode.OrderedChildren.Populate(children) @@ -632,12 +670,12 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64 } // Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { +func (s *StaticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } s.fdOpts = fdOpts - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } // Open implements Inode.Open. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 606081e68..565d723f0 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -61,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -107,6 +108,23 @@ type Filesystem struct { // nextInoMinusOne is used to to allocate inode numbers on this // filesystem. Must be accessed by atomic operations. nextInoMinusOne uint64 + + // cachedDentries contains all dentries with 0 references. (Due to race + // conditions, it may also contain dentries with non-zero references.) + // cachedDentriesLen is the number of dentries in cachedDentries. These + // fields are protected by mu. + cachedDentries dentryList + cachedDentriesLen uint64 + + // MaxCachedDentries is the maximum size of cachedDentries. If not set, + // defaults to 0 and kernfs does not cache any dentries. This is immutable. + MaxCachedDentries uint64 + + // root is the root dentry of this filesystem. Note that root may be nil for + // filesystems on a disconnected mount without a root (e.g. pipefs, sockfs, + // hostfs). Filesystem holds an extra reference on root to prevent it from + // being destroyed prematurely. This is immutable. + root *Dentry } // deferDecRef defers dropping a dentry ref until the next call to @@ -165,7 +183,12 @@ const ( // +stateify savable type Dentry struct { vfsd vfs.Dentry - DentryRefs + + // refs is the reference count. When refs reaches 0, the dentry may be + // added to the cache or destroyed. If refs == -1, the dentry has already + // been destroyed. refs are allowed to go to 0 and increase again. refs is + // accessed using atomic memory operations. + refs int64 // fs is the owning filesystem. fs is immutable. fs *Filesystem @@ -177,6 +200,12 @@ type Dentry struct { parent *Dentry name string + // If cached is true, dentryEntry links dentry into + // Filesystem.cachedDentries. cached and dentryEntry are protected by + // Filesystem.mu. + cached bool + dentryEntry + // dirMu protects children and the names of child Dentries. // // Note that holding fs.mu for writing is not sufficient; @@ -188,6 +217,209 @@ type Dentry struct { inode Inode } +// IncRef implements vfs.DentryImpl.IncRef. +func (d *Dentry) IncRef() { + // d.refs may be 0 if d.fs.mu is locked, which serializes against + // d.cacheLocked(). + r := atomic.AddInt64(&d.refs, 1) + if d.LogRefs() { + refsvfs2.LogIncRef(d, r) + } +} + +// TryIncRef implements vfs.DentryImpl.TryIncRef. +func (d *Dentry) TryIncRef() bool { + for { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + if d.LogRefs() { + refsvfs2.LogTryIncRef(d, r+1) + } + return true + } + } +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *Dentry) DecRef(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } + if r == 0 { + d.fs.mu.Lock() + d.cacheLocked(ctx) + d.fs.mu.Unlock() + } else if r < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } +} + +func (d *Dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } + if r == 0 { + d.cacheLocked(ctx) + } else if r < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } +} + +// cacheLocked should be called after d's reference count becomes 0. The ref +// count check may happen before acquiring d.fs.mu so there might be a race +// condition where the ref count is increased again by the time the caller +// acquires d.fs.mu. This race is handled. +// Only reachable dentries are added to the cache. However, a dentry might +// become unreachable *while* it is in the cache due to invalidation. +// +// Preconditions: d.fs.mu must be locked for writing. +func (d *Dentry) cacheLocked(ctx context.Context) { + // Dentries with a non-zero reference count must be retained. (The only way + // to obtain a reference on a dentry with zero references is via path + // resolution, which requires d.fs.mu, so if d.refs is zero then it will + // remain zero while we hold d.fs.mu for writing.) + refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + return + } + if refs > 0 { + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + return + } + // If the dentry is deleted and invalidated or has no parent, then it is no + // longer reachable by path resolution and should be dropped immediately + // because it has zero references. + // Note that a dentry may not always have a parent; for example magic links + // as described in Inode.Getlink. + if isDead := d.VFSDentry().IsDead(); isDead || d.parent == nil { + if !isDead { + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) + } + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + d.destroyLocked(ctx) + return + } + // If d is already cached, just move it to the front of the LRU. + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentries.PushFront(d) + return + } + // Cache the dentry, then evict the least recently used cached dentry if + // the cache becomes over-full. + d.fs.cachedDentries.PushFront(d) + d.fs.cachedDentriesLen++ + d.cached = true + if d.fs.cachedDentriesLen <= d.fs.MaxCachedDentries { + return + } + d.fs.evictCachedDentryLocked(ctx) + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.opts.maxCachedDentries, so we don't loop. +} + +// Preconditions: +// * fs.mu must be locked for writing. +// * fs.cachedDentriesLen != 0. +func (fs *Filesystem) evictCachedDentryLocked(ctx context.Context) { + // Evict the least recently used dentry because cache size is greater than + // max cache size (configured on mount). + victim := fs.cachedDentries.Back() + fs.cachedDentries.Remove(victim) + fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // after it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if !victim.vfsd.IsDead() { + victim.parent.dirMu.Lock() + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, victim.VFSDentry()) + delete(victim.parent.children, victim.name) + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.MaxCachedDentries, so we don't loop. +} + +// destroyLocked destroys the dentry. +// +// Preconditions: +// * d.fs.mu must be locked for writing. +// * d.refs == 0. +// * d should have been removed from d.parent.children, i.e. d is not reachable +// by path traversal. +// * d.vfsd.IsDead() is true. +func (d *Dentry) destroyLocked(ctx context.Context) { + refs := atomic.LoadInt64(&d.refs) + switch refs { + case 0: + // Mark the dentry destroyed. + atomic.StoreInt64(&d.refs, -1) + case -1: + panic("dentry.destroyLocked() called on already destroyed dentry") + default: + panic("dentry.destroyLocked() called with references on the dentry") + } + + d.inode.DecRef(ctx) // IncRef from Init. + d.inode = nil + + if d.parent != nil { + d.parent.decRefLocked(ctx) + } + + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *Dentry) RefType() string { + return "kernfs.Dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *Dentry) LeakMessage() string { + return fmt.Sprintf("[kernfs.Dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *Dentry) LogRefs() bool { + return false +} + +// InitRoot initializes this dentry as the root of the filesystem. +// +// Precondition: Caller must hold a reference on inode. +// +// Postcondition: Caller's reference on inode is transferred to the dentry. +func (d *Dentry) InitRoot(fs *Filesystem, inode Inode) { + d.Init(fs, inode) + fs.root = d + // Hold an extra reference on the root dentry. It is held by fs to prevent the + // root from being "cached" and subsequently evicted. + d.IncRef() +} + // Init initializes this dentry. // // Precondition: Caller must hold a reference on inode. @@ -197,6 +429,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { d.vfsd.Init(d) d.fs = fs d.inode = inode + atomic.StoreInt64(&d.refs, 1) ftype := inode.Mode().FileType() if ftype == linux.ModeDirectory { d.flags |= dflagsIsDir @@ -204,7 +437,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { if ftype == linux.ModeSymlink { d.flags |= dflagsIsSymlink } - d.EnableLeakCheck() + refsvfs2.Register(d) } // VFSDentry returns the generic vfs dentry for this kernfs dentry. @@ -222,32 +455,6 @@ func (d *Dentry) isSymlink() bool { return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0 } -// DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef(ctx context.Context) { - decRefParent := false - d.fs.mu.Lock() - d.DentryRefs.DecRef(func() { - d.inode.DecRef(ctx) // IncRef from Init. - d.inode = nil - if d.parent != nil { - // We will DecRef d.parent once all locks are dropped. - decRefParent = true - d.parent.dirMu.Lock() - // Remove d from parent.children. It might already have been - // removed due to invalidation. - if _, ok := d.parent.children[d.name]; ok { - delete(d.parent.children, d.name) - d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) - } - d.parent.dirMu.Unlock() - } - }) - d.fs.mu.Unlock() - if decRefParent { - d.parent.DecRef(ctx) // IncRef from Dentry.insertChild. - } -} - // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // // Although Linux technically supports inotify on pseudo filesystems (inotify @@ -267,7 +474,9 @@ func (d *Dentry) OnZeroWatches(context.Context) {} // this dentry. This does not update the directory inode, so calling this on its // own isn't sufficient to insert a child into a directory. // -// Precondition: d must represent a directory inode. +// Preconditions: +// * d must represent a directory inode. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChild(name string, child *Dentry) { d.dirMu.Lock() d.insertChildLocked(name, child) @@ -280,6 +489,7 @@ func (d *Dentry) insertChild(name string, child *Dentry) { // Preconditions: // * d must represent a directory inode. // * d.dirMu must be locked. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChildLocked(name string, child *Dentry) { if !d.isDir() { panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d)) @@ -436,7 +646,7 @@ type inodeDirectory interface { // the inode is a directory. // // The child returned by Lookup will be hashed into the VFS dentry tree, - // atleast for the duration of the current FS operation. + // at least for the duration of the current FS operation. // // Lookup must return the child with an extra reference whose ownership is // transferred to the dentry that is created to point to that inode. If @@ -454,7 +664,7 @@ type inodeDirectory interface { // inside the entries returned by this IterDirents invocation. In other words, // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as // the return value, while 'relOffset' is the place to start iteration. - IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) + IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) } type inodeSymlink interface { diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 82fa19c03..e63588e33 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file." // RootDentryFn is a generator function for creating the root dentry of a test // filesystem. See newTestSystem. -type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode +type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode // newTestSystem sets up a minimal environment for running a test, including an // instance of a test filesystem. Tests can control the contents of the @@ -72,10 +72,10 @@ type file struct { content string } -func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode { +func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode { f := &file{} f.content = content - f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) + f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) return f } @@ -105,11 +105,11 @@ type readonlyDir struct { locks vfs.FileLocks } -func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &readonlyDir{} - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - dir.EnableLeakCheck() + dir.InitRefs() dir.IncLinks(dir.OrderedChildren.Populate(contents)) return dir } @@ -142,12 +142,12 @@ type dir struct { fs *filesystem } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &dir{} dir.fs = fs - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) - dir.EnableLeakCheck() + dir.InitRefs() dir.IncLinks(dir.OrderedChildren.Populate(contents)) return dir @@ -169,22 +169,24 @@ func (d *dir) DecRef(ctx context.Context) { func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - dir := d.fs.newDir(creds, opts.Mode, nil) + dir := d.fs.newDir(ctx, creds, opts.Mode, nil) if err := d.OrderedChildren.Insert(name, dir); err != nil { dir.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) d.IncLinks(1) return dir, nil } func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - f := d.fs.newFile(creds, "") + f := d.fs.newFile(ctx, creds, "") if err := d.OrderedChildren.Insert(name, f); err != nil { f.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) return f, nil } @@ -209,7 +211,7 @@ func (fsType) Release(ctx context.Context) {} func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { fs := &filesystem{} fs.VFSFilesystem().Init(vfsObj, &fst, fs) - root := fst.rootFn(creds, fs) + root := fst.rootFn(ctx, creds, fs) var d kernfs.Dentry d.Init(&fs.Filesystem, root) return fs.VFSFilesystem(), d.VFSDentry(), nil @@ -218,9 +220,9 @@ func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesyst // -------------------- Remainder of the file are test cases -------------------- func TestBasic(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -228,9 +230,9 @@ func TestBasic(t *testing.T) { } func TestMkdirGetDentry(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -243,9 +245,9 @@ func TestMkdirGetDentry(t *testing.T) { } func TestReadStaticFile(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -269,9 +271,9 @@ func TestReadStaticFile(t *testing.T) { } func TestCreateNewFileInStaticDir(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -296,8 +298,8 @@ func TestCreateNewFileInStaticDir(t *testing.T) { } func TestDirFDReadWrite(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, nil) + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, nil) }) defer sys.Destroy() @@ -320,14 +322,14 @@ func TestDirFDReadWrite(t *testing.T) { } func TestDirFDIterDirents(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ // Fill root with nodes backed by various inode implementations. - "dir1": fs.newReadonlyDir(creds, 0755, nil), - "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{ - "dir3": fs.newDir(creds, 0755, nil), + "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil), + "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir3": fs.newDir(ctx, creds, 0755, nil), }), - "file1": fs.newFile(creds, staticFileContent), + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/kernfs/mmap_util.go index b51a17bed..bd6a134b4 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/kernfs/mmap_util.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package host +package kernfs import ( "gvisor.dev/gvisor/pkg/context" @@ -26,11 +26,14 @@ import ( // inodePlatformFile implements memmap.File. It exists solely because inode // cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // -// inodePlatformFile should only be used if inode.canMap is true. -// // +stateify savable type inodePlatformFile struct { - *inode + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + // inodePlatformFile does not own hostFD and hence should not close it. + hostFD int // fdRefsMu protects fdRefs. fdRefsMu sync.Mutex `state:"nosave"` @@ -43,12 +46,12 @@ type inodePlatformFile struct { fileMapper fsutil.HostFileMapper // fileMapperInitOnce is used to lazily initialize fileMapper. - fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + fileMapperInitOnce sync.Once `state:"nosave"` } +var _ memmap.File = (*inodePlatformFile)(nil) + // IncRef implements memmap.File.IncRef. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.IncRefAndAccount(fr) @@ -56,8 +59,6 @@ func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { } // DecRef implements memmap.File.DecRef. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.DecRefAndAccount(fr) @@ -65,8 +66,6 @@ func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { } // MapInternal implements memmap.File.MapInternal. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } @@ -76,10 +75,32 @@ func (i *inodePlatformFile) FD() int { return i.hostFD } -// AddMapping implements memmap.Mappable.AddMapping. +// CachedMappable implements memmap.Mappable. This utility can be embedded in a +// kernfs.Inode that represents a host file to make the inode mappable. +// CachedMappable caches the mappings of the host file. CachedMappable must be +// initialized (via Init) with a hostFD before use. // -// Precondition: i.inode.canMap must be true. -func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +// +stateify savable +type CachedMappable struct { + // mapsMu protects mappings. + mapsMu sync.Mutex `state:"nosave"` + + // mappings tracks mappings of hostFD into memmap.MappingSpaces. + mappings memmap.MappingSet + + // pf implements memmap.File for mappings backed by a host fd. + pf inodePlatformFile +} + +var _ memmap.Mappable = (*CachedMappable)(nil) + +// Init initializes i.pf. This must be called before using CachedMappable. +func (i *CachedMappable) Init(hostFD int) { + i.pf.hostFD = hostFD +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { i.mapsMu.Lock() mapped := i.mappings.AddMapping(ms, ar, offset, writable) for _, r := range mapped { @@ -90,9 +111,7 @@ func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar userm } // RemoveMapping implements memmap.Mappable.RemoveMapping. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { i.mapsMu.Lock() unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable) for _, r := range unmapped { @@ -102,16 +121,12 @@ func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar us } // CopyMapping implements memmap.Mappable.CopyMapping. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { return i.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { mr := optional return []memmap.Translation{ { @@ -124,10 +139,26 @@ func (i *inode) Translate(ctx context.Context, required, optional memmap.Mappabl } // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) InvalidateUnsavable(ctx context.Context) error { +func (i *CachedMappable) InvalidateUnsavable(ctx context.Context) error { // We expect the same host fd across save/restore, so all translations // should be valid. return nil } + +// InvalidateRange invalidates the passed range on i.mappings. +func (i *CachedMappable) InvalidateRange(r memmap.MappableRange) { + i.mapsMu.Lock() + i.mappings.Invalidate(r, memmap.InvalidateOpts{ + // Compare Linux's mm/truncate.c:truncate_setsize() => + // truncate_pagecache() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + i.mapsMu.Unlock() +} + +// InitFileMapperOnce initializes the host file mapper. It ensures that the +// file mapper is initialized just once. +func (i *CachedMappable) InitFileMapperOnce() { + i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init) +} diff --git a/pkg/sentry/fsimpl/kernfs/save_restore.go b/pkg/sentry/fsimpl/kernfs/save_restore.go new file mode 100644 index 000000000..f78509eb7 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/save_restore.go @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// afterLoad is invoked by stateify. +func (d *Dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) >= 0 { + refsvfs2.Register(d) + } +} + +// afterLoad is invoked by stateify. +func (i *inodePlatformFile) afterLoad() { + if i.fileMapper.IsInited() { + // Ensure that we don't call i.fileMapper.Init() again. + i.fileMapperInitOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 934cc6c9e..a0736c0d6 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -38,16 +38,16 @@ type StaticSymlink struct { var _ Inode = (*StaticSymlink)(nil) // NewStaticSymlink creates a new symlink file pointing to 'target'. -func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { +func NewStaticSymlink(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { inode := &StaticSymlink{} - inode.Init(creds, devMajor, devMinor, ino, target) + inode.Init(ctx, creds, devMajor, devMinor, ino, target) return inode } // Init initializes the instance. -func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { +func (s *StaticSymlink) Init(ctx context.Context, creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { s.target = target - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } // Readlink implements Inode.Readlink. diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go index d0ed17b18..463d77d79 100644 --- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -41,17 +41,17 @@ type syntheticDirectory struct { var _ Inode = (*syntheticDirectory)(nil) -func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode { +func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode { inode := &syntheticDirectory{} - inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) + inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) return inode } -func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) } - dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) dir.OrderedChildren.Init(OrderedChildrenOptions{ Writable: true, }) @@ -76,11 +76,12 @@ func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs if !opts.ForSyntheticMountpoint { return nil, syserror.EPERM } - subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) + subdirI := newSyntheticDirectory(ctx, auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) if err := dir.OrderedChildren.Insert(name, subdirI); err != nil { subdirI.DecRef(ctx) return nil, err } + dir.TouchCMtime(ctx) return subdirI, nil } diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index 1e11b0428..bf13bbbf4 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -23,6 +23,7 @@ go_library( "fstree.go", "overlay.go", "regular_file.go", + "save_restore.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -30,6 +31,8 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 4506642ca..27b00cf6f 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -16,7 +16,6 @@ package overlay import ( "fmt" - "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -129,25 +128,9 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { return err } defer newFD.DecRef(ctx) - bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size - for { - readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{}) - if readErr != nil && readErr != io.EOF { - cleanupUndoCopyUp() - return readErr - } - total := int64(0) - for total < readN { - writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{}) - total += writeN - if writeErr != nil { - cleanupUndoCopyUp() - return writeErr - } - } - if readErr == io.EOF { - break - } + if _, err := vfs.CopyRegularFileData(ctx, newFD, oldFD); err != nil { + cleanupUndoCopyUp() + return err } d.mapsMu.Lock() defer d.mapsMu.Unlock() @@ -409,7 +392,7 @@ func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) er if dirent.Name == "." || dirent.Name == ".." { continue } - child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) + child, _, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 78a01bbb7..d55bdc97f 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -77,26 +78,36 @@ func putDentrySlice(ds *[]*dentry) { } // renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls -// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for +// dentry.checkDropLocked on all dentries in *dsp with fs.renameMu locked for // writing. // -// ds is a pointer-to-pointer since defer evaluates its arguments immediately, +// dsp is a pointer-to-pointer since defer evaluates its arguments immediately, // but dentry slices are allocated lazily, and it's much easier to say "defer // fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() { // fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this. -func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { +func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, dsp **[]*dentry) { fs.renameMu.RUnlock() - if *ds == nil { + if *dsp == nil { return } - if len(**ds) != 0 { + ds := **dsp + // Only go through calling dentry.checkDropLocked() (which requires + // re-locking renameMu) if we actually have any dentries with zero refs. + checkAny := false + for i := range ds { + if atomic.LoadInt64(&ds[i].refs) == 0 { + checkAny = true + break + } + } + if checkAny { fs.renameMu.Lock() - for _, d := range **ds { + for _, d := range ds { d.checkDropLocked(ctx) } fs.renameMu.Unlock() } - putDentrySlice(*ds) + putDentrySlice(*dsp) } func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) { @@ -121,63 +132,63 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de // * fs.renameMu must be locked. // * d.dirMu must be locked. // * !rp.Done(). -func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, lookupLayer, error) { if !d.isDir() { - return nil, syserror.ENOTDIR + return nil, lookupLayerNone, syserror.ENOTDIR } if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, err + return nil, lookupLayerNone, err } afterSymlink: name := rp.Component() if name == "." { rp.Advance() - return d, nil + return d, d.topLookupLayer(), nil } if name == ".." { if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } else if isRoot || d.parent == nil { rp.Advance() - return d, nil + return d, d.topLookupLayer(), nil } if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } rp.Advance() - return d.parent, nil + return d.parent, d.parent.topLookupLayer(), nil } - child, err := fs.getChildLocked(ctx, d, name, ds) + child, topLookupLayer, err := fs.getChildLocked(ctx, d, name, ds) if err != nil { - return nil, err + return nil, topLookupLayer, err } if err := rp.CheckMount(ctx, &child.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx) if err != nil { - return nil, err + return nil, lookupLayerNone, err } if err := rp.HandleSymlink(target); err != nil { - return nil, err + return nil, topLookupLayer, err } goto afterSymlink // don't check the current directory again } rp.Advance() - return child, nil + return child, topLookupLayer, nil } // Preconditions: // * fs.renameMu must be locked. // * d.dirMu must be locked. -func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, lookupLayer, error) { if child, ok := parent.children[name]; ok { - return child, nil + return child, child.topLookupLayer(), nil } - child, err := fs.lookupLocked(ctx, parent, name) + child, topLookupLayer, err := fs.lookupLocked(ctx, parent, name) if err != nil { - return nil, err + return nil, topLookupLayer, err } if parent.children == nil { parent.children = make(map[string]*dentry) @@ -185,16 +196,16 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s parent.children[name] = child // child's refcount is initially 0, so it may be dropped after traversal. *ds = appendDentry(*ds, child) - return child, nil + return child, topLookupLayer, nil } // Preconditions: // * fs.renameMu must be locked. // * parent.dirMu must be locked. -func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { +func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, lookupLayer, error) { childPath := fspath.Parse(name) child := fs.newDentry() - existsOnAnyLayer := false + topLookupLayer := lookupLayerNone var lookupErr error vfsObj := fs.vfsfs.VirtualFilesystem() @@ -215,7 +226,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str defer childVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE) - if !existsOnAnyLayer { + if topLookupLayer == lookupLayerNone { // Mode, UID, GID, and (for non-directories) inode number come from // the topmost layer on which the file exists. mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO @@ -238,10 +249,13 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if isWhiteout(&stat) { // This is a whiteout, so it "doesn't exist" on this layer, and // layers below this one are ignored. + if isUpper { + topLookupLayer = lookupLayerUpperWhiteout + } return false } isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR - if existsOnAnyLayer && !isDir { + if topLookupLayer != lookupLayerNone && !isDir { // Directories are not merged with non-directory files from lower // layers; instead, layers including and below the first // non-directory file are ignored. (This file must be a directory @@ -258,8 +272,12 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str } else { child.lowerVDs = append(child.lowerVDs, childVD) } - if !existsOnAnyLayer { - existsOnAnyLayer = true + if topLookupLayer == lookupLayerNone { + if isUpper { + topLookupLayer = lookupLayerUpper + } else { + topLookupLayer = lookupLayerLower + } child.mode = uint32(stat.Mode) child.uid = stat.UID child.gid = stat.GID @@ -288,11 +306,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if lookupErr != nil { child.destroyLocked(ctx) - return nil, lookupErr + return nil, topLookupLayer, lookupErr } - if !existsOnAnyLayer { + if !topLookupLayer.existsInOverlay() { child.destroyLocked(ctx) - return nil, syserror.ENOENT + return nil, topLookupLayer, syserror.ENOENT } // Device and inode numbers were copied from the topmost layer above; @@ -302,14 +320,20 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str child.devMinor = fs.dirDevMinor child.ino = fs.newDirIno() } else if !child.upperVD.Ok() { + childDevMinor, err := fs.getLowerDevMinor(child.devMajor, child.devMinor) + if err != nil { + ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err) + child.destroyLocked(ctx) + return nil, topLookupLayer, err + } child.devMajor = linux.UNNAMED_MAJOR - child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()] + child.devMinor = childDevMinor } parent.IncRef() child.parent = parent child.name = name - return child, nil + return child, topLookupLayer, nil } // lookupLayerLocked is similar to lookupLocked, but only returns information @@ -408,7 +432,7 @@ func (ll lookupLayer) existsInOverlay() bool { func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err @@ -428,7 +452,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, d := rp.Start().Impl().(*dentry) for !rp.Done() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err @@ -463,9 +487,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if name == "." || name == ".." { return syserror.EEXIST } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if parent.vfsd.IsDead() { return syserror.ENOENT } @@ -489,6 +510,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } + // Ensure that the parent directory is copied-up so that we can create the // new file in the upper layer. if err := parent.copyUpLocked(ctx); err != nil { @@ -791,9 +816,9 @@ afterTrailingSymlink: } // Determine whether or not we need to create a file. parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) if err == syserror.ENOENT && mayCreate { - fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds) + fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout) parent.dirMu.Unlock() return fd, err } @@ -893,7 +918,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts * // Preconditions: // * parent.dirMu must be locked. // * parent does not already contain a child named rp.Component(). -func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { +func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry, haveUpperWhiteout bool) (*vfs.FileDescription, error) { creds := rp.Credentials() if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil { return nil, err @@ -918,19 +943,12 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving Start: parent.upperVD, Path: fspath.Parse(childName), } - // We don't know if a whiteout exists on the upper layer; speculatively - // unlink it. - // - // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do - // know whether a whiteout exists. - var haveUpperWhiteout bool - switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err { - case nil: - haveUpperWhiteout = true - case syserror.ENOENT: - haveUpperWhiteout = false - default: - return nil, err + // Unlink the whiteout if it exists. + if haveUpperWhiteout { + if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil { + log.Warningf("overlay.filesystem.createAndOpenLocked: failed to unlink whiteout: %v", err) + return nil, err + } } // Create the file on the upper layer, and get an FD representing it. upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{ @@ -961,7 +979,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } // Re-lookup to get a dentry representing the new file, which is needed for // the returned FD. - child, err := fs.getChildLocked(ctx, parent, childName, ds) + child, _, err := fs.getChildLocked(ctx, parent, childName, ds) if err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)) @@ -970,7 +988,10 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } - // Finally construct the overlay FD. + // Finally construct the overlay FD. Below this point, we don't perform + // cleanup (the file was created successfully even if we can no longer open + // it for some reason). + parent.dirents = nil upperFlags := upperFD.StatusFlags() fd := ®ularFileFD{ copiedUp: true, @@ -981,8 +1002,6 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving upperFDOpts := upperFD.Options() if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil { upperFD.DecRef(ctx) - // Don't bother with cleanup; the file was created successfully, we - // just can't open it anymore for some reason. return nil, err } parent.watches.Notify(ctx, childName, linux.IN_CREATE, 0 /* cookie */, vfs.PathEvent, false /* unlinked */) @@ -1040,7 +1059,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // directory, we need to check for write permission on it. oldParent.dirMu.Lock() defer oldParent.dirMu.Unlock() - renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) + renamed, _, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) if err != nil { return err } @@ -1072,20 +1091,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.vfsd.IsDead() { return syserror.ENOENT } - replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName) - if err != nil { - return err - } var ( - replaced *dentry - replacedVFSD *vfs.Dentry - whiteouts map[string]bool + replaced *dentry + replacedVFSD *vfs.Dentry + replacedLayer lookupLayer + whiteouts map[string]bool ) - if replacedLayer.existsInOverlay() { - replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds) - if err != nil { - return err - } + replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil && err != syserror.ENOENT { + return err + } + if replaced != nil { replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { @@ -1289,7 +1305,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // Unlike UnlinkAt, we need a dentry representing the child directory being // removed in order to verify that it's empty. - child, err := fs.getChildLocked(ctx, parent, name, &ds) + child, _, err := fs.getChildLocked(ctx, parent, name, &ds) if err != nil { return err } @@ -1541,7 +1557,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if parentMode&linux.S_ISVTX != 0 { // If the parent's sticky bit is set, we need a child dentry to get // its owner. - child, err = fs.getChildLocked(ctx, parent, name, &ds) + child, _, err = fs.getChildLocked(ctx, parent, name, &ds) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 4c5de8d32..3492409b2 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -22,6 +22,7 @@ // filesystem.renameMu // dentry.dirMu // dentry.copyMu +// filesystem.devMu // *** "memmap.Mappable locks" below this point // dentry.mapsMu // *** "memmap.Mappable locks taken by Translate" below this point @@ -33,12 +34,14 @@ package overlay import ( + "fmt" "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -99,10 +102,15 @@ type filesystem struct { // is immutable. dirDevMinor uint32 - // lowerDevMinors maps lower layer filesystems to device minor numbers - // assigned to non-directory files originating from that filesystem. - // lowerDevMinors is immutable. - lowerDevMinors map[*vfs.Filesystem]uint32 + // lowerDevMinors maps device numbers from lower layer filesystems to + // device minor numbers assigned to non-directory files originating from + // that filesystem. (This remapping is necessary for lower layers because a + // file on a lower layer, and that same file on an overlay, are + // distinguishable because they will diverge after copy-up; this isn't true + // for non-directory files already on the upper layer.) lowerDevMinors is + // protected by devMu. + devMu sync.Mutex `state:"nosave"` + lowerDevMinors map[layerDevNumber]uint32 // renameMu synchronizes renaming with non-renaming operations in order to // ensure consistent lock ordering between dentry.dirMu in different @@ -114,78 +122,69 @@ type filesystem struct { lastDirIno uint64 } +// +stateify savable +type layerDevNumber struct { + major uint32 + minor uint32 +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { mopts := vfs.GenericParseMountOptions(opts.Data) fsoptsRaw := opts.InternalData - fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions) - if fsoptsRaw != nil && !haveFSOpts { + fsopts, ok := fsoptsRaw.(FilesystemOptions) + if fsoptsRaw != nil && !ok { ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) return nil, nil, syserror.EINVAL } - if haveFSOpts { - if len(fsopts.LowerRoots) == 0 { - ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + + if upperPathname, ok := mopts["upperdir"]; ok { + if fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified") return nil, nil, syserror.EINVAL } - if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") + delete(mopts, "upperdir") + // Linux overlayfs also requires a workdir when upperdir is + // specified; we don't, so silently ignore this option. + delete(mopts, "workdir") + upperPath := fspath.Parse(upperPathname) + if !upperPath.Absolute { + ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) return nil, nil, syserror.EINVAL } - // We don't enforce a maximum number of lower layers when not - // configured by applications; the sandbox owner can have an overlay - // filesystem with any number of lower layers. - } else { - vfsroot := vfs.RootFromContext(ctx) - defer vfsroot.DecRef(ctx) - upperPathname, ok := mopts["upperdir"] - if ok { - delete(mopts, "upperdir") - // Linux overlayfs also requires a workdir when upperdir is - // specified; we don't, so silently ignore this option. - delete(mopts, "workdir") - upperPath := fspath.Parse(upperPathname) - if !upperPath.Absolute { - ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) - return nil, nil, syserror.EINVAL - } - upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ - Root: vfsroot, - Start: vfsroot, - Path: upperPath, - FollowFinalSymlink: true, - }, &vfs.GetDentryOptions{ - CheckSearchable: true, - }) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer upperRoot.DecRef(ctx) - privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer privateUpperRoot.DecRef(ctx) - fsopts.UpperRoot = privateUpperRoot + upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: upperPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) + return nil, nil, err + } + privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) + upperRoot.DecRef(ctx) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) + return nil, nil, err } - lowerPathnamesStr, ok := mopts["lowerdir"] - if !ok { - ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") + defer privateUpperRoot.DecRef(ctx) + fsopts.UpperRoot = privateUpperRoot + } + + if lowerPathnamesStr, ok := mopts["lowerdir"]; ok { + if len(fsopts.LowerRoots) != 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified") return nil, nil, syserror.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") - const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK - if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") - return nil, nil, syserror.EINVAL - } - if len(lowerPathnames) > maxLowerLayers { - ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) - return nil, nil, syserror.EINVAL - } for _, lowerPathname := range lowerPathnames { lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { @@ -204,8 +203,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) + lowerRoot.DecRef(ctx) if err != nil { ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err @@ -214,31 +213,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot) } } + if len(mopts) != 0 { ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) return nil, nil, syserror.EINVAL } - // Allocate device numbers. + if len(fsopts.LowerRoots) == 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required") + return nil, nil, syserror.EINVAL + } + if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present") + return nil, nil, syserror.EINVAL + } + const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK + if len(fsopts.LowerRoots) > maxLowerLayers { + ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers) + return nil, nil, syserror.EINVAL + } + + // Allocate dirDevMinor. lowerDevMinors are allocated dynamically. dirDevMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } - lowerDevMinors := make(map[*vfs.Filesystem]uint32) - for _, lowerRoot := range fsopts.LowerRoots { - lowerFS := lowerRoot.Mount().Filesystem() - if _, ok := lowerDevMinors[lowerFS]; !ok { - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - vfsObj.PutAnonBlockDevMinor(dirDevMinor) - for _, lowerDevMinor := range lowerDevMinors { - vfsObj.PutAnonBlockDevMinor(lowerDevMinor) - } - return nil, nil, err - } - lowerDevMinors[lowerFS] = devMinor - } - } // Take extra references held by the filesystem. if fsopts.UpperRoot.Ok() { @@ -252,7 +251,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt opts: fsopts, creds: creds.Fork(), dirDevMinor: dirDevMinor, - lowerDevMinors: lowerDevMinors, + lowerDevMinors: make(map[layerDevNumber]uint32), } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -302,7 +301,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt root.ino = fs.newDirIno() } else if !root.upperVD.Ok() { root.devMajor = linux.UNNAMED_MAJOR - root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()] + rootDevMinor, err := fs.getLowerDevMinor(rootStat.DevMajor, rootStat.DevMinor) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to get device number for root: %v", err) + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) + return nil, nil, err + } + root.devMinor = rootDevMinor root.ino = rootStat.Ino } else { root.devMajor = rootStat.DevMajor @@ -375,6 +381,21 @@ func (fs *filesystem) newDirIno() uint64 { return atomic.AddUint64(&fs.lastDirIno, 1) } +func (fs *filesystem) getLowerDevMinor(layerMajor, layerMinor uint32) (uint32, error) { + fs.devMu.Lock() + defer fs.devMu.Unlock() + orig := layerDevNumber{layerMajor, layerMinor} + if minor, ok := fs.lowerDevMinors[orig]; ok { + return minor, nil + } + minor, err := fs.vfsfs.VirtualFilesystem().GetAnonBlockDevMinor() + if err != nil { + return 0, err + } + fs.lowerDevMinors[orig] = minor + return minor, nil +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -458,9 +479,9 @@ type dentry struct { // // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is // accessed using atomic memory operations. - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` lowerMappings memmap.MappingSet - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` wrappedMappable memmap.Mappable isMappable uint32 @@ -484,6 +505,7 @@ func (fs *filesystem) newDentry() *dentry { } d.lowerVDs = d.inlineLowerVDs[:0] d.vfsd.Init(d) + refsvfs2.Register(d) return d } @@ -491,17 +513,23 @@ func (fs *filesystem) newDentry() *dentry { func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against // d.checkDropLocked(). - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + if d.LogRefs() { + refsvfs2.LogIncRef(d, r) + } } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + if d.LogRefs() { + refsvfs2.LogTryIncRef(d, r+1) + } return true } } @@ -509,15 +537,31 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + r := atomic.AddInt64(&d.refs, -1) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } + if r == 0 { d.fs.renameMu.Lock() d.checkDropLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { + } else if r < 0 { panic("overlay.dentry.DecRef() called without holding a reference") } } +func (d *dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } + if r == 0 { + d.checkDropLocked(ctx) + } else if r < 0 { + panic("overlay.dentry.decRefLocked() called without holding a reference") + } +} + // checkDropLocked should be called after d's reference count becomes 0 or it // becomes deleted. // @@ -577,12 +621,27 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.parent.dirMu.Unlock() // Drop the reference held by d on its parent without recursively // locking d.fs.renameMu. - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkDropLocked(ctx) - } else if refs < 0 { - panic("overlay.dentry.DecRef() called without holding a reference") - } + d.parent.decRefLocked(ctx) } + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *dentry) RefType() string { + return "overlay.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[overlay.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -645,6 +704,13 @@ func (d *dentry) topLayer() vfs.VirtualDentry { return vd } +func (d *dentry) topLookupLayer() lookupLayer { + if d.upperVD.Ok() { + return lookupLayerUpper + } + return lookupLayerLower +} + func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 2b89a7a6d..25c785fd4 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -103,8 +103,8 @@ func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescript for e, mask := range fd.lowerWaiters { fd.cachedFD.EventUnregister(e) upperFD.EventRegister(e, mask) - if ready&mask != 0 { - e.Callback.Callback(e) + if m := ready & mask; m != 0 { + e.Callback.Callback(e, m) } } } diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/sentry/fsimpl/overlay/save_restore.go index 19d6b0b15..54809f16c 100644 --- a/pkg/syncevent/waiter_asm_unsafe.go +++ b/pkg/sentry/fsimpl/overlay/save_restore.go @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 arm64 - -package syncevent +package overlay import ( - "unsafe" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" ) -// See waiter_noasm_unsafe.go for a description of waiterUnlock. -func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool +func (d *dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index e44b79b68..0ecb592cf 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -101,7 +101,7 @@ type inode struct { func newInode(ctx context.Context, fs *filesystem) *inode { creds := auth.CredentialsFromContext(ctx) return &inode{ - pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize), ino: fs.Filesystem.NextIno(), uid: creds.EffectiveKUID, gid: creds.EffectiveKGID, diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 2e086e34c..5196a2a80 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "fd_dir_inode_refs.go", package = "proc", prefix = "fdDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdDirInode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "fd_info_dir_inode_refs.go", package = "proc", prefix = "fdInfoDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdInfoDirInode", }, @@ -30,7 +30,7 @@ go_template_instance( out = "subtasks_inode_refs.go", package = "proc", prefix = "subtasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "subtasksInode", }, @@ -41,7 +41,7 @@ go_template_instance( out = "task_inode_refs.go", package = "proc", prefix = "taskInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "taskInode", }, @@ -52,7 +52,7 @@ go_template_instance( out = "tasks_inode_refs.go", package = "proc", prefix = "tasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tasksInode", }, @@ -82,6 +82,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index fd70a07de..8716d0a3c 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -17,6 +17,7 @@ package proc import ( "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -24,10 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "proc" +const ( + // Name is the default filesystem name. + Name = "proc" + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType is the factory class for procfs. // @@ -63,9 +68,22 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF if err != nil { return nil, nil, err } + + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + procfs := &filesystem{ devMinor: devMinor, } + procfs.MaxCachedDentries = maxCachedDentries procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) var cgroups map[string]string @@ -74,9 +92,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF cgroups = data.Cgroups } - inode := procfs.newTasksInode(k, pidns, cgroups) + inode := procfs.newTasksInode(ctx, k, pidns, cgroups) var dentry kernfs.Dentry - dentry.Init(&procfs.Filesystem, inode) + dentry.InitRoot(&procfs.Filesystem, inode) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil } @@ -94,11 +112,11 @@ type dynamicInode interface { kernfs.Inode vfs.DynamicBytesSource - Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) + Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) } -func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { - inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) +func (fs *filesystem) newInode(ctx context.Context, creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) return inode } @@ -114,8 +132,8 @@ func newStaticFile(data string) *staticFile { return &staticFile{StaticData: vfs.StaticData{Data: data}} } -func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { - return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ +func (fs *filesystem) newStaticDir(ctx context.Context, creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { + return kernfs.NewStaticDir(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) } diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index bad2fab4f..c53cc0122 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -50,7 +50,7 @@ type subtasksInode struct { var _ kernfs.Inode = (*subtasksInode)(nil) -func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) kernfs.Inode { +func (fs *filesystem) newSubtasks(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) kernfs.Inode { subInode := &subtasksInode{ fs: fs, task: task, @@ -58,9 +58,9 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers: cgroupControllers, } // Note: credentials are overridden by taskOwnedInode. - subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + subInode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - subInode.EnableLeakCheck() + subInode.InitRefs() inode := &taskOwnedInode{Inode: subInode, owner: task} return inode @@ -80,11 +80,11 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, if subTask.ThreadGroup() != i.task.ThreadGroup() { return nil, syserror.ENOENT } - return i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers) + return i.fs.newTaskInode(ctx, subTask, i.pidns, false, i.cgroupControllers) } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { return offset, syserror.ENOENT diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index b63a4eca0..fea138f93 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -47,50 +47,51 @@ type taskInode struct { var _ kernfs.Inode = (*taskInode)(nil) -func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) { +func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) { if task.ExitState() == kernel.TaskExitDead { return nil, syserror.ESRCH } contents := map[string]kernfs.Inode{ - "auxv": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &auxvData{task: task}), - "cmdline": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), - "comm": fs.newComm(task, fs.NextIno(), 0444), - "cwd": fs.newCwdSymlink(task, fs.NextIno()), - "environ": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), - "exe": fs.newExeSymlink(task, fs.NextIno()), - "fd": fs.newFDDirInode(task), - "fdinfo": fs.newFDInfoDirInode(task), - "gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), - "io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), - "maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}), - "mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}), - "mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}), - "net": fs.newTaskNetDir(task), - "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]kernfs.Inode{ - "net": fs.newNamespaceSymlink(task, fs.NextIno(), "net"), - "pid": fs.newNamespaceSymlink(task, fs.NextIno(), "pid"), - "user": fs.newNamespaceSymlink(task, fs.NextIno(), "user"), + "auxv": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &auxvData{task: task}), + "cmdline": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), + "comm": fs.newComm(ctx, task, fs.NextIno(), 0444), + "cwd": fs.newCwdSymlink(ctx, task, fs.NextIno()), + "environ": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), + "exe": fs.newExeSymlink(ctx, task, fs.NextIno()), + "fd": fs.newFDDirInode(ctx, task), + "fdinfo": fs.newFDInfoDirInode(ctx, task), + "gid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), + "io": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), + "maps": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &mapsData{task: task}), + "mem": fs.newMemInode(ctx, task, fs.NextIno(), 0400), + "mountinfo": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &mountInfoData{task: task}), + "mounts": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &mountsData{task: task}), + "net": fs.newTaskNetDir(ctx, task), + "ns": fs.newTaskOwnedDir(ctx, task, fs.NextIno(), 0511, map[string]kernfs.Inode{ + "net": fs.newNamespaceSymlink(ctx, task, fs.NextIno(), "net"), + "pid": fs.newNamespaceSymlink(ctx, task, fs.NextIno(), "pid"), + "user": fs.newNamespaceSymlink(ctx, task, fs.NextIno(), "user"), }), - "oom_score": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, newStaticFile("0\n")), - "oom_score_adj": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}), - "smaps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &smapsData{task: task}), - "stat": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), - "statm": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &statmData{task: task}), - "status": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}), - "uid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), + "oom_score": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newStaticFile("0\n")), + "oom_score_adj": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &oomScoreAdj{task: task}), + "smaps": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &smapsData{task: task}), + "stat": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), + "statm": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statmData{task: task}), + "status": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}), + "uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { - contents["task"] = fs.newSubtasks(task, pidns, cgroupControllers) + contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers) } if len(cgroupControllers) > 0 { - contents["cgroup"] = fs.newTaskOwnedInode(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers)) + contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers)) } taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. - taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) - taskInode.EnableLeakCheck() + taskInode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + taskInode.InitRefs() inode := &taskOwnedInode{Inode: taskInode, owner: task} @@ -142,17 +143,17 @@ type taskOwnedInode struct { var _ kernfs.Inode = (*taskOwnedInode)(nil) -func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode { +func (fs *filesystem) newTaskOwnedInode(ctx context.Context, task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) return &taskOwnedInode{Inode: inode, owner: task} } -func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newTaskOwnedDir(ctx context.Context, task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero} - dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) + dir := kernfs.NewStaticDir(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) return &taskOwnedInode{Inode: dir, owner: task} } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 2c80ac5c2..02bf74dbc 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -64,7 +64,7 @@ type fdDir struct { } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *fdDir) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { @@ -119,7 +119,7 @@ type fdDirInode struct { var _ kernfs.Inode = (*fdDirInode)(nil) -func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { +func (fs *filesystem) newFDDirInode(ctx context.Context, task *kernel.Task) kernfs.Inode { inode := &fdDirInode{ fdDir: fdDir{ fs: fs, @@ -127,15 +127,15 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { produceSymlink: true, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) - inode.EnableLeakCheck() + inode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InitRefs() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Lookup implements kernfs.inodeDirectory.Lookup. @@ -148,7 +148,7 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err if !taskFDExists(ctx, i.task, fd) { return nil, syserror.ENOENT } - return i.fs.newFDSymlink(i.task, fd, i.fs.NextIno()), nil + return i.fs.newFDSymlink(ctx, i.task, fd, i.fs.NextIno()), nil } // Open implements kernfs.Inode.Open. @@ -204,12 +204,12 @@ type fdSymlink struct { var _ kernfs.Inode = (*fdSymlink)(nil) -func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kernfs.Inode { +func (fs *filesystem) newFDSymlink(ctx context.Context, task *kernel.Task, fd int32, ino uint64) kernfs.Inode { inode := &fdSymlink{ task: task, fd: fd, } - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -257,15 +257,15 @@ type fdInfoDirInode struct { var _ kernfs.Inode = (*fdInfoDirInode)(nil) -func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode { +func (fs *filesystem) newFDInfoDirInode(ctx context.Context, task *kernel.Task) kernfs.Inode { inode := &fdInfoDirInode{ fdDir: fdDir{ fs: fs, task: task, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) - inode.EnableLeakCheck() + inode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InitRefs() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode } @@ -284,12 +284,12 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, task: i.task, fd: fd, } - return i.fs.newTaskOwnedInode(i.task, i.fs.NextIno(), 0444, data), nil + return i.fs.newTaskOwnedInode(ctx, i.task, i.fs.NextIno(), 0444, data), nil } // IterDirents implements Inode.IterDirents. -func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdInfoDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 79f8b7e9f..a3780b222 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -247,9 +248,9 @@ type commInode struct { task *kernel.Task } -func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { +func (fs *filesystem) newComm(ctx context.Context, task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { inode := &commInode{task: task} - inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) + inode.DynamicBytesFile.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) return inode } @@ -366,6 +367,162 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in return int64(srclen), nil } +var _ kernfs.Inode = (*memInode)(nil) + +// memInode implements kernfs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memInode struct { + kernfs.InodeAttrs + kernfs.InodeNoStatFS + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + + task *kernel.Task + locks vfs.FileLocks +} + +func (fs *filesystem) newMemInode(ctx context.Context, task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { + // Note: credentials are overridden by taskOwnedInode. + inode := &memInode{task: task} + inode.init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) + return &taskOwnedInode{Inode: inode, owner: task} +} + +func (f *memInode) init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) + } + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) +} + +// Open implements kernfs.Inode.Open. +func (f *memInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, f.task, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(f.task); err != nil { + return nil, err + } + fd := &memFD{} + if err := fd.Init(rp.Mount(), d, f, opts.Flags); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// SetStat implements kernfs.Inode.SetStat. +func (*memInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +var _ vfs.FileDescriptionImpl = (*memFD)(nil) + +// memFD implements vfs.FileDescriptionImpl for /proc/[pid]/mem. +// +// +stateify savable +type memFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + inode *memInode + + // mu guards the fields below. + mu sync.Mutex `state:"nosave"` + offset int64 +} + +// Init initializes memFD. +func (fd *memFD) Init(m *vfs.Mount, d *kernfs.Dentry, inode *memInode, flags uint32) error { + fd.LockFD.Init(&inode.locks) + if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return err + } + fd.inode = inode + return nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + switch whence { + case linux.SEEK_SET: + case linux.SEEK_CUR: + offset += fd.offset + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.offset = offset + return offset, nil +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + m, err := getMMIncRef(fd.inode.task) + if err != nil { + return 0, nil + } + defer m.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *memFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.offset, opts) + fd.offset += n + fd.mu.Unlock() + return n, err +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *memFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *memFD) Release(context.Context) {} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} + // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. // // +stateify savable @@ -655,9 +812,9 @@ type exeSymlink struct { var _ kernfs.Inode = (*exeSymlink)(nil) -func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode { +func (fs *filesystem) newExeSymlink(ctx context.Context, task *kernel.Task, ino uint64) kernfs.Inode { inode := &exeSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -731,9 +888,9 @@ type cwdSymlink struct { var _ kernfs.Inode = (*cwdSymlink)(nil) -func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode { +func (fs *filesystem) newCwdSymlink(ctx context.Context, task *kernel.Task, ino uint64) kernfs.Inode { inode := &cwdSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -842,7 +999,7 @@ type namespaceSymlink struct { task *kernel.Task } -func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) kernfs.Inode { +func (fs *filesystem) newNamespaceSymlink(ctx context.Context, task *kernel.Task, ino uint64, ns string) kernfs.Inode { // Namespace symlinks should contain the namespace name and the inode number // for the namespace instance, so for example user:[123456]. We currently fake // the inode number by sticking the symlink inode in its place. @@ -850,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri inode := &namespaceSymlink{task: task} // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) + inode.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) taskInode := &taskOwnedInode{Inode: inode, owner: task} return taskInode @@ -872,8 +1029,10 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir // Create a synthetic inode to represent the namespace. fs := mnt.Filesystem().Impl().(*filesystem) + nsInode := &namespaceInode{} + nsInode.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0444) dentry := &kernfs.Dentry{} - dentry.Init(&fs.Filesystem, &namespaceInode{}) + dentry.Init(&fs.Filesystem, nsInode) vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry()) // Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1. mnt.IncRef() @@ -897,11 +1056,11 @@ type namespaceInode struct { var _ kernfs.Inode = (*namespaceInode)(nil) // Init initializes a namespace inode. -func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (i *namespaceInode) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + i.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 3425e8698..d4f6a5a9b 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -37,7 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode { +func (fs *filesystem) newTaskNetDir(ctx context.Context, task *kernel.Task) kernfs.Inode { k := task.Kernel() pidns := task.PIDNamespace() root := auth.NewRootCredentials(pidns.UserNamespace()) @@ -57,37 +57,37 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]kernfs.Inode{ - "dev": fs.newInode(root, 0444, &netDevData{stack: stack}), - "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}), + "dev": fs.newInode(ctx, root, 0444, &netDevData{stack: stack}), + "snmp": fs.newInode(ctx, root, 0444, &netSnmpData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, if the file contains a header the stub is just the header // otherwise it is an empty file. - "arp": fs.newInode(root, 0444, newStaticFile(arp)), - "netlink": fs.newInode(root, 0444, newStaticFile(netlink)), - "netstat": fs.newInode(root, 0444, &netStatData{}), - "packet": fs.newInode(root, 0444, newStaticFile(packet)), - "protocols": fs.newInode(root, 0444, newStaticFile(protocols)), + "arp": fs.newInode(ctx, root, 0444, newStaticFile(arp)), + "netlink": fs.newInode(ctx, root, 0444, newStaticFile(netlink)), + "netstat": fs.newInode(ctx, root, 0444, &netStatData{}), + "packet": fs.newInode(ctx, root, 0444, newStaticFile(packet)), + "protocols": fs.newInode(ctx, root, 0444, newStaticFile(protocols)), // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000, // high res timer ticks per sec (ClockGetres returns 1ns resolution). - "psched": fs.newInode(root, 0444, newStaticFile(psched)), - "ptype": fs.newInode(root, 0444, newStaticFile(ptype)), - "route": fs.newInode(root, 0444, &netRouteData{stack: stack}), - "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}), - "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}), - "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}), + "psched": fs.newInode(ctx, root, 0444, newStaticFile(psched)), + "ptype": fs.newInode(ctx, root, 0444, newStaticFile(ptype)), + "route": fs.newInode(ctx, root, 0444, &netRouteData{stack: stack}), + "tcp": fs.newInode(ctx, root, 0444, &netTCPData{kernel: k}), + "udp": fs.newInode(ctx, root, 0444, &netUDPData{kernel: k}), + "unix": fs.newInode(ctx, root, 0444, &netUnixData{kernel: k}), } if stack.SupportsIPv6() { - contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack}) - contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile("")) - contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k}) - contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6)) + contents["if_inet6"] = fs.newInode(ctx, root, 0444, &ifinet6{stack: stack}) + contents["ipv6_route"] = fs.newInode(ctx, root, 0444, newStaticFile("")) + contents["tcp6"] = fs.newInode(ctx, root, 0444, &netTCP6Data{kernel: k}) + contents["udp6"] = fs.newInode(ctx, root, 0444, newStaticFile(upd6)) } } - return fs.newTaskOwnedDir(task, fs.NextIno(), 0555, contents) + return fs.newTaskOwnedDir(ctx, task, fs.NextIno(), 0555, contents) } // ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6. @@ -208,7 +208,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { for _, se := range n.kernel.ListSockets() { s := se.SockVFS2 if !s.TryIncRef() { - log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) + // Racing with socket destruction, this is ok. continue } if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX { @@ -351,7 +351,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, for _, se := range k.ListSockets() { s := se.SockVFS2 if !s.TryIncRef() { - log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) + // Racing with socket destruction, this is ok. continue } sops, ok := s.Impl().(socket.SocketVFS2) @@ -516,7 +516,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { for _, se := range d.kernel.ListSockets() { s := se.SockVFS2 if !s.TryIncRef() { - log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) + // Racing with socket destruction, this is ok. continue } sops, ok := s.Impl().(socket.SocketVFS2) diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 3259c3732..fdc580610 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -62,19 +62,19 @@ type tasksInode struct { var _ kernfs.Inode = (*tasksInode)(nil) -func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { +func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ - "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))), - "filesystems": fs.newInode(root, 0444, &filesystemsData{}), - "loadavg": fs.newInode(root, 0444, &loadavgData{}), - "sys": fs.newSysDir(root, k), - "meminfo": fs.newInode(root, 0444, &meminfoData{}), - "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), - "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), - "stat": fs.newInode(root, 0444, &statData{}), - "uptime": fs.newInode(root, 0444, &uptimeData{}), - "version": fs.newInode(root, 0444, &versionData{}), + "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), + "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), + "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), + "sys": fs.newSysDir(ctx, root, k), + "meminfo": fs.newInode(ctx, root, 0444, &meminfoData{}), + "mounts": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), + "net": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), + "stat": fs.newInode(ctx, root, 0444, &statData{}), + "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}), + "version": fs.newInode(ctx, root, 0444, &versionData{}), } inode := &tasksInode{ @@ -82,8 +82,8 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace fs: fs, cgroupControllers: cgroupControllers, } - inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) - inode.EnableLeakCheck() + inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InitRefs() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) links := inode.OrderedChildren.Populate(contents) @@ -106,9 +106,9 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err // If it failed to parse, check if it's one of the special handled files. switch name { case selfName: - return i.newSelfSymlink(root), nil + return i.newSelfSymlink(ctx, root), nil case threadSelfName: - return i.newThreadSelfSymlink(root), nil + return i.newThreadSelfSymlink(ctx, root), nil } return nil, syserror.ENOENT } @@ -118,11 +118,11 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err return nil, syserror.ENOENT } - return i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers) + return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers) } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { +func (i *tasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 const FIRST_PROCESS_ENTRY = 256 diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 07c27cdd9..01b7a6678 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -43,9 +43,9 @@ type selfSymlink struct { var _ kernfs.Inode = (*selfSymlink)(nil) -func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &selfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } @@ -84,9 +84,9 @@ type threadSelfSymlink struct { var _ kernfs.Inode = (*threadSelfSymlink)(nil) -func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newThreadSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &threadSelfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 95420368d..25c407d98 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -40,93 +40,94 @@ const ( ) // newSysDir returns the dentry corresponding to /proc/sys directory. -func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { - return fs.newStaticDir(root, map[string]kernfs.Inode{ - "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{ - "hostname": fs.newInode(root, 0444, &hostnameData{}), - "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)), - "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)), - "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)), +func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { + return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), + "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), + "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), + "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), + "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), }), - "vm": fs.newStaticDir(root, map[string]kernfs.Inode{ - "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}), - "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")), + "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}), + "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")), }), - "net": fs.newSysNetDir(root, k), + "net": fs.newSysNetDir(ctx, root, k), }) } // newSysNetDir returns the dentry corresponding to /proc/sys/net directory. -func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { +func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { var contents map[string]kernfs.Inode // TODO(gvisor.dev/issue/1833): Support for using the network stack in the // network namespace of the calling process. if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]kernfs.Inode{ - "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{ - "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}), - "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), - "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}), - "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), - "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}), + "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")), - "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")), - "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")), - "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")), - "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")), + "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")), + "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")), + "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")), + "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "ip_no_pmtu_disc": fs.newInode(ctx, root, 0444, newStaticFile("1")), // tcp_allowed_congestion_control tell the user what they are able to // do as an unprivledged process so we leave it empty. - "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")), - "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), - "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), + "tcp_allowed_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_available_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), + "tcp_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), // Many of the following stub files are features netstack doesn't // support. The unsupported features return "0" to indicate they are // disabled. - "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")), - "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")), - "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")), - "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")), - "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")), - "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")), + "tcp_base_mss": fs.newInode(ctx, root, 0444, newStaticFile("1280")), + "tcp_dsack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_early_retrans": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen_key": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_invalid_ratelimit": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_intvl": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_probes": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_time": fs.newInode(ctx, root, 0444, newStaticFile("7200")), + "tcp_mtu_probing": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_no_metrics_save": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_probe_interval": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_probe_threshold": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_retries1": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_retries2": fs.newInode(ctx, root, 0444, newStaticFile("15")), + "tcp_rfc1337": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_slow_start_after_idle": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_synack_retries": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "tcp_syn_retries": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_timestamps": fs.newInode(ctx, root, 0444, newStaticFile("1")), }), - "core": fs.newStaticDir(root, map[string]kernfs.Inode{ - "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")), - "message_burst": fs.newInode(root, 0444, newStaticFile("10")), - "message_cost": fs.newInode(root, 0444, newStaticFile("5")), - "optmem_max": fs.newInode(root, 0444, newStaticFile("0")), - "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")), - "somaxconn": fs.newInode(root, 0444, newStaticFile("128")), - "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")), + "core": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "default_qdisc": fs.newInode(ctx, root, 0444, newStaticFile("pfifo_fast")), + "message_burst": fs.newInode(ctx, root, 0444, newStaticFile("10")), + "message_cost": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "optmem_max": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "rmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "rmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "somaxconn": fs.newInode(ctx, root, 0444, newStaticFile("128")), + "wmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "wmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), }), } } - return fs.newStaticDir(root, contents) + return fs.newStaticDir(ctx, root, contents) } // mmapMinAddrData implements vfs.DynamicBytesSource for diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 2582ababd..7ee6227a9 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -77,6 +77,7 @@ var ( "gid_map": linux.DT_REG, "io": linux.DT_REG, "maps": linux.DT_REG, + "mem": linux.DT_REG, "mountinfo": linux.DT_REG, "mounts": linux.DT_REG, "net": linux.DT_DIR, diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 10f1452ef..246bd87bc 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package signalfd provides basic signalfd file implementations. package signalfd import ( @@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index cf91ea36c..fda1fa942 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -108,13 +108,13 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // NewDentry constructs and returns a sockfs dentry. // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). -func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry { +func NewDentry(ctx context.Context, mnt *vfs.Mount) *vfs.Dentry { fs := mnt.Filesystem().Impl().(*filesystem) // File mode matches net/socket.c:sock_alloc. filemode := linux.FileMode(linux.S_IFSOCK | 0600) i := &inode{} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) + i.InodeAttrs.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) d := &kernfs.Dentry{} d.Init(&fs.Filesystem, i) diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 906cd52cb..09043b572 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "dir_refs.go", package = "sys", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -28,6 +28,7 @@ go_library( "//pkg/coverage", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index 31a361029..b13f141a8 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -29,7 +29,7 @@ import ( func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode { k := &kcovInode{} - k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) + k.InodeAttrs.Init(ctx, creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) return k } diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1ad679830..dbd9ebdda 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -18,10 +18,12 @@ package sys import ( "bytes" "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/coverage" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -29,9 +31,12 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "sysfs" -const defaultSysDirMode = linux.FileMode(0755) +const ( + // Name is the default filesystem name. + Name = "sysfs" + defaultSysDirMode = linux.FileMode(0755) + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType implements vfs.FilesystemType. // @@ -62,31 +67,43 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + fs := &filesystem{ devMinor: devMinor, } + fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) - root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "block": fs.newDir(creds, defaultSysDirMode, nil), - "bus": fs.newDir(creds, defaultSysDirMode, nil), - "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "power_supply": fs.newDir(creds, defaultSysDirMode, nil), + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "class": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "power_supply": fs.newDir(ctx, creds, defaultSysDirMode, nil), }), - "dev": fs.newDir(creds, defaultSysDirMode, nil), - "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ + "dev": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "devices": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "system": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "cpu": cpuDir(ctx, fs, creds), }), }), - "firmware": fs.newDir(creds, defaultSysDirMode, nil), - "fs": fs.newDir(creds, defaultSysDirMode, nil), + "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), "kernel": kernelDir(ctx, fs, creds), - "module": fs.newDir(creds, defaultSysDirMode, nil), - "power": fs.newDir(creds, defaultSysDirMode, nil), + "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), }) var rootD kernfs.Dentry - rootD.Init(&fs.Filesystem, root) + rootD.InitRoot(&fs.Filesystem, root) return fs.VFSFilesystem(), rootD.VFSDentry(), nil } @@ -94,14 +111,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs k := kernel.KernelFromContext(ctx) maxCPUCores := k.ApplicationCores() children := map[string]kernfs.Inode{ - "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), + "online": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "possible": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "present": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), } for i := uint(0); i < maxCPUCores; i++ { - children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil) + children[fmt.Sprintf("cpu%d", i)] = fs.newDir(ctx, creds, linux.FileMode(0555), nil) } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode { @@ -110,13 +127,14 @@ func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) ker // keep it in sys. var children map[string]kernfs.Inode if coverage.KcovAvailable() { + log.Debugf("Set up /sys/kernel/debug/kcov") children = map[string]kernfs.Inode{ - "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{ + "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{ "kcov": fs.newKcovFile(ctx, creds), }), } } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } // Release implements vfs.FilesystemImpl.Release. @@ -140,11 +158,11 @@ type dir struct { locks vfs.FileLocks } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { d := &dir{} - d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) + d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - d.EnableLeakCheck() + d.InitRefs() d.IncLinks(d.OrderedChildren.Populate(contents)) return d } @@ -191,9 +209,9 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { +func (fs *filesystem) newCPUFile(ctx context.Context, creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { c := &cpuFile{maxCores: maxCores} - c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) + c.DynamicBytesFile.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) return c } diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 738c0c9cc..205ad8192 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -136,7 +136,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns config := &kernel.TaskConfig{ Kernel: k, ThreadGroup: tc, - TaskContext: &kernel.TaskContext{Name: name, MemoryManager: m}, + TaskImage: &kernel.TaskImage{Name: name, MemoryManager: m}, Credentials: auth.CredentialsFromContext(ctx), NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 5cd428d64..09957c2b7 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -31,7 +31,7 @@ go_template_instance( out = "inode_refs.go", package = "tmpfs", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -48,6 +48,7 @@ go_library( "inode_refs.go", "named_pipe.go", "regular_file.go", + "save_restore.go", "socket_file.go", "symlink.go", "tmpfs.go", @@ -60,11 +61,13 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index e39cd305b..9296db2fb 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -381,6 +382,8 @@ afterTrailingSymlink: creds := rp.Credentials() child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)) parentDir.insertChildLocked(child, name) + child.IncRef() + defer child.DecRef(ctx) unlock() fd, err := child.open(ctx, rp, &opts, true) if err != nil { @@ -437,6 +440,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open return nil, err } } + if fd.vfsfd.IsWritable() { + fsmetric.TmpfsOpensW.Increment() + } else if fd.vfsfd.IsReadable() { + fsmetric.TmpfsOpensRO.Increment() + } return &fd.vfsfd, nil case *directory: // Can't open directories writably. diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index d772db9e9..57e7b57b0 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/usermem" ) // +stateify savable @@ -32,7 +31,7 @@ type namedPipe struct { // * fs.mu must be locked. // * rp.Mount().CheckBeginWrite() has been called successfully. func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { - file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)} + file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize)} file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode) file.inode.nlink = 1 // Only the parent has a link. return &file.inode diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index ce4e3eda7..6255a7c84 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -42,7 +43,7 @@ type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. - memFile *pgalloc.MemoryFile + memFile *pgalloc.MemoryFile `state:"nosave"` // memoryUsageKind is the memory accounting category under which pages backing // this regularFile's contents are accounted. @@ -92,7 +93,7 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ - memFile: fs.memFile, + memFile: fs.mfp.MemoryFile(), memoryUsageKind: usage.Tmpfs, seals: linux.F_SEAL_SEAL, } @@ -359,6 +360,10 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + start := fsmetric.StartReadWait() + defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start) + fsmetric.TmpfsReads.Increment() + if offset < 0 { return 0, syserror.EINVAL } @@ -565,7 +570,7 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. // -// Preconditions: inode.mu must be held. +// Preconditions: rw.file.inode.mu must be held. func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { // Hold dataMu so we can modify size. rw.file.dataMu.Lock() @@ -657,7 +662,7 @@ exitLoop: // If the write ends beyond the file's previous size, it causes the // file to grow. if rw.off > rw.file.size { - rw.file.size = rw.off + atomic.StoreUint64(&rw.file.size, rw.off) } return done, retErr diff --git a/pkg/sleep/commit_asm.go b/pkg/sentry/fsimpl/tmpfs/save_restore.go index 75728a97d..b27f75cc2 100644 --- a/pkg/sleep/commit_asm.go +++ b/pkg/sentry/fsimpl/tmpfs/save_restore.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build amd64 arm64 +package tmpfs -package sleep - -// See commit_noasm.go for a description of commitSleep. -func commitSleep(g uintptr, waitingG *uintptr) bool +// afterLoad is called by stateify. +func (rf *regularFile) afterLoad() { + rf.memFile = rf.inode.fs.mfp.MemoryFile() +} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index e2a0aac69..0c9c639d3 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -61,8 +61,9 @@ type FilesystemType struct{} type filesystem struct { vfsfs vfs.Filesystem - // memFile is used to allocate pages to for regular files. - memFile *pgalloc.MemoryFile + // mfp is used to allocate memory that stores regular file contents. mfp is + // immutable. + mfp pgalloc.MemoryFileProvider // clock is a realtime clock used to set timestamps in file operations. clock time.Clock @@ -106,8 +107,8 @@ type FilesystemOpts struct { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx) - if memFileProvider == nil { + mfp := pgalloc.MemoryFileProviderFromContext(ctx) + if mfp == nil { panic("MemoryFileProviderFromContext returned nil") } @@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } clock := time.RealtimeClockFromContext(ctx) fs := filesystem{ - memFile: memFileProvider.MemoryFile(), + mfp: mfp, clock: clock, devMinor: devMinor, } @@ -401,7 +402,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth i.mtime = now // i.nlink initialized by caller i.impl = impl - i.refs.EnableLeakCheck() + i.refs.InitRefs() } // incLinksLocked increments i's link count. @@ -477,9 +478,9 @@ func (i *inode) statTo(stat *linux.Statx) { stat.GID = atomic.LoadUint32(&i.gid) stat.Mode = uint16(atomic.LoadUint32(&i.mode)) stat.Ino = i.ino - stat.Atime = linux.NsecToStatxTimestamp(i.atime) - stat.Ctime = linux.NsecToStatxTimestamp(i.ctime) - stat.Mtime = linux.NsecToStatxTimestamp(i.mtime) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&i.atime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&i.ctime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&i.mtime)) stat.DevMajor = linux.UNNAMED_MAJOR stat.DevMinor = i.fs.devMinor switch impl := i.impl.(type) { @@ -630,7 +631,8 @@ func (i *inode) direntType() uint8 { } func (i *inode) isDir() bool { - return linux.FileMode(i.mode).FileType() == linux.S_IFDIR + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + return mode.FileType() == linux.S_IFDIR } func (i *inode) touchAtime(mnt *vfs.Mount) { diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 0ca750281..e265be0ee 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -6,6 +6,7 @@ go_library( name = "verity", srcs = [ "filesystem.go", + "save_restore.go", "verity.go", ], visibility = ["//pkg/sentry:internal"], @@ -15,6 +16,7 @@ go_library( "//pkg/fspath", "//pkg/marshal/primitive", "//pkg/merkletree", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel", @@ -38,10 +40,12 @@ go_test( "//pkg/context", "//pkg/fspath", "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/contexttest", "//pkg/sentry/vfs", + "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 03da505e1..04e7110a3 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -16,6 +16,7 @@ package verity import ( "bytes" + "encoding/json" "fmt" "io" "strconv" @@ -31,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. @@ -105,8 +107,10 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de // Dentries which may have a reference count of zero, and which therefore // should be dropped once traversal is complete, are appended to ds. // -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. -// !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !rp.Done(). func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { if !d.isDir() { return nil, syserror.ENOTDIR @@ -156,15 +160,19 @@ afterSymlink: return child, nil } -// verifyChild verifies the hash of child against the already verified hash of -// the parent to ensure the child is expected. verifyChild triggers a sentry -// panic if unexpected modifications to the file system are detected. In +// verifyChildLocked verifies the hash of child against the already verified +// hash of the parent to ensure the child is expected. verifyChild triggers a +// sentry panic if unexpected modifications to the file system are detected. In // noCrashOnVerificationFailure mode it returns a syserror instead. -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// // TODO(b/166474175): Investigate all possible errors returned in this // function, and make sure we differentiate all errors that indicate unexpected // modifications to the file system from the ones that are not harmful. -func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { +func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() // Get the path to the child dentry. This is only used to provide path @@ -192,7 +200,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -201,7 +209,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's hash. @@ -215,7 +223,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -233,7 +241,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -243,10 +251,10 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: parentMerkleFD, Ctx: ctx, } @@ -256,7 +264,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de Start: parent.lowerVD, }, &vfs.StatOptions{}) if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -266,33 +274,44 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contain the hash of the children in the parent Merkle tree when // Verify returns with success. var buf bytes.Buffer - if _, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, + parent.hashMu.RLock() + _, err = merkletree.Verify(&merkletree.VerifyParams{ + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + Children: parent.childrenNames, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), - ReadSize: int64(merkletree.DigestSize()), + ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), Expected: parent.hash, DataAndTreeInSameFile: true, - }); err != nil && err != io.EOF { - return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + }) + parent.hashMu.RUnlock() + if err != nil && err != io.EOF { + return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. + child.hashMu.Lock() if len(child.hash) == 0 { child.hash = buf.Bytes() } + child.hashMu.Unlock() return child, nil } -// verifyStat verifies the stat against the verified hash. The mode/uid/gid of -// the file is cached after verified. -func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Statx) error { +// verifyStatAndChildrenLocked verifies the stat and children names against the +// verified hash. The mode/uid/gid and childrenNames of the file is cached +// after verified. +// +// Preconditions: d.dirMu must be locked. +func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry, stat linux.Statx) error { vfsObj := fs.vfsfs.VirtualFilesystem() // Get the path to the child dentry. This is only used to provide path @@ -312,7 +331,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { return err @@ -324,7 +343,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat }) if err == syserror.ENODATA { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return err @@ -332,45 +351,135 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat size, err := strconv.Atoi(merkleSize) if err != nil { - return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + } + + if d.isDir() && len(d.childrenNames) == 0 { + childrenOffString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: childrenOffsetXattr, + Size: sizeOfStringInt32, + }) + + if err == syserror.ENODATA { + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) + } + if err != nil { + return err + } + childrenOffset, err := strconv.Atoi(childrenOffString) + if err != nil { + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + } + + childrenSizeString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: childrenSizeXattr, + Size: sizeOfStringInt32, + }) + + if err == syserror.ENODATA { + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) + } + if err != nil { + return err + } + childrenSize, err := strconv.Atoi(childrenSizeString) + if err != nil { + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + } + + childrenNames := make([]byte, childrenSize) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(childrenOffset), vfs.ReadOptions{}); err != nil { + return alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err)) + } + + if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { + return alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err)) + } } - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: fd, Ctx: ctx, } var buf bytes.Buffer + d.hashMu.RLock() params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - ReadOffset: 0, + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + Children: d.childrenNames, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fs.alg.toLinuxHashAlg(), + ReadOffset: 0, // Set read size to 0 so only the metadata is verified. ReadSize: 0, Expected: d.hash, DataAndTreeInSameFile: false, } + d.hashMu.RUnlock() if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR { params.DataAndTreeInSameFile = true } if _, err := merkletree.Verify(params); err != nil && err != io.EOF { - return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) } d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID + d.size = uint32(size) return nil } -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { + // If verity is enabled on child, we should check again whether + // the file and the corresponding Merkle tree are as expected, + // in order to catch deletion/renaming after the last time it's + // accessed. + if child.verityEnabled() { + vfsObj := fs.vfsfs.VirtualFilesystem() + // Get the path to the child dentry. This is only used + // to provide path information in failure case. + path, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD) + if err != nil { + return nil, err + } + + childVD, err := parent.getLowerAt(ctx, vfsObj, name) + if err == syserror.ENOENT { + // The file was previously accessed. If the + // file does not exist now, it indicates an + // unexpected modification to the file system. + return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) + } + if err != nil { + return nil, err + } + defer childVD.DecRef(ctx) + + childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + // The Merkle tree file was previous accessed. If it + // does not exist now, it indicates an unexpected + // modification to the file system. + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) + } + if err != nil { + return nil, err + } + + defer childMerkleVD.DecRef(ctx) + } + // If enabling verification on files/directories is not allowed // during runtime, all cached children are already verified. If // runtime enable is allowed and the parent directory is @@ -378,7 +487,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // be cached before enabled. if fs.allowRuntimeEnable { if parent.verityEnabled() { - if _, err := fs.verifyChild(ctx, parent, child); err != nil { + if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil { return nil, err } } @@ -394,7 +503,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s if err != nil { return nil, err } - if err := fs.verifyStat(ctx, child, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil { return nil, err } } @@ -414,112 +523,64 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s return child, nil } -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() - childFilename := fspath.Parse(name) - childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: childFilename, - }, &vfs.GetDentryOptions{}) - - // We will handle ENOENT separately, as it may indicate unexpected - // modifications to the file system, and may cause a sentry panic. - if childErr != nil && childErr != syserror.ENOENT { - return nil, childErr + if parent.verityEnabled() { + if _, ok := parent.childrenNames[name]; !ok { + return nil, syserror.ENOENT + } } - // The dentry needs to be cleaned up if any error occurs. IncRef will be - // called if a verity child dentry is successfully created. - if childErr == nil { - defer childVD.DecRef(ctx) + parentPath, err := vfsObj.PathnameWithDeleted(ctx, parent.fs.rootDentry.lowerVD, parent.lowerVD) + if err != nil { + return nil, err } - childMerkleFilename := merklePrefix + name - childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) - - // We will handle ENOENT separately, as it may indicate unexpected - // modifications to the file system, and may cause a sentry panic. - if childMerkleErr != nil && childMerkleErr != syserror.ENOENT { - return nil, childMerkleErr + childVD, err := parent.getLowerAt(ctx, vfsObj, name) + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) + } + if err != nil { + return nil, err } // The dentry needs to be cleaned up if any error occurs. IncRef will be // called if a verity child dentry is successfully created. - if childMerkleErr == nil { - defer childMerkleVD.DecRef(ctx) - } + defer childVD.DecRef(ctx) - // Get the path to the parent dentry. This is only used to provide path - // information in failure case. - parentPath, err := vfsObj.PathnameWithDeleted(ctx, parent.fs.rootDentry.lowerVD, parent.lowerVD) - if err != nil { + childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + if err == syserror.ENOENT { + if !fs.allowRuntimeEnable { + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) + } + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(merklePrefix + name), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + if err != nil { + return nil, err + } + } + if err != nil && err != syserror.ENOENT { return nil, err } - // TODO(b/166474175): Investigate all possible errors of childErr and - // childMerkleErr, and make sure we differentiate all errors that - // indicate unexpected modifications to the file system from the ones - // that are not harmful. - if childErr == syserror.ENOENT && childMerkleErr == nil { - // Failed to get child file/directory dentry. However the - // corresponding Merkle tree is found. This indicates an - // unexpected modification to the file system that - // removed/renamed the child. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) - } else if childErr == nil && childMerkleErr == syserror.ENOENT { - // If in allowRuntimeEnable mode, and the Merkle tree file is - // not created yet, we create an empty Merkle tree file, so that - // if the file is enabled through ioctl, we have the Merkle tree - // file open and ready to use. - // This may cause empty and unused Merkle tree files in - // allowRuntimeEnable mode, if they are never enabled. This - // does not affect verification, as we rely on cached hash to - // decide whether to perform verification, not the existence of - // the Merkle tree file. Also, those Merkle tree files are - // always hidden and cannot be accessed by verity fs users. - if fs.allowRuntimeEnable { - childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT, - Mode: 0644, - }) - if err != nil { - return nil, err - } - childMerkleFD.DecRef(ctx) - childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) - if err != nil { - return nil, err - } - } else { - // If runtime enable is not allowed. This indicates an - // unexpected modification to the file system that - // removed/renamed the Merkle tree file. - return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) - } - } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { - // Both the child and the corresponding Merkle tree are missing. - // This could be an unexpected modification or due to incorrect - // parameter. - // TODO(b/167752508): Investigate possible ways to differentiate - // cases that both files are deleted from cases that they never - // exist in the file system. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) - } + // The dentry needs to be cleaned up if any error occurs. IncRef will be + // called if a verity child dentry is successfully created. + defer childMerkleVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{ @@ -549,18 +610,19 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, child.mode = uint32(stat.Mode) child.uid = stat.UID child.gid = stat.GID + child.childrenNames = make(map[string]struct{}) // Verify child hash. This should always be performed unless in // allowRuntimeEnable mode and the parent directory hasn't been enabled // yet. if parent.verityEnabled() { - if _, err := fs.verifyChild(ctx, parent, child); err != nil { + if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil { child.destroyLocked(ctx) return nil, err } } if child.verityEnabled() { - if err := fs.verifyStat(ctx, child, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil { child.destroyLocked(ctx) return nil, err } @@ -574,7 +636,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // rp.Start().Impl().(*dentry)). It does not check that the returned directory // is searchable by the provider of rp. // -// Preconditions: fs.renameMu must be locked. !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() @@ -762,7 +826,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // missing, it indicates an unexpected modification to the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err } @@ -785,7 +849,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -810,7 +874,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }) if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -828,7 +892,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err != nil { if err == syserror.ENOENT { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } return nil, err } @@ -915,11 +979,13 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } + d.dirMu.Lock() if d.verityEnabled() { - if err := fs.verifyStat(ctx, d, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { return linux.Statx{}, err } } + d.dirMu.Unlock() return stat, nil } diff --git a/pkg/sentry/fsimpl/verity/save_restore.go b/pkg/sentry/fsimpl/verity/save_restore.go new file mode 100644 index 000000000..46b064342 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 8dc9e26bc..9571ce9f1 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -19,10 +19,24 @@ // The verity file system is read-only, except for one case: when // allowRuntimeEnable is true, additional Merkle files can be generated using // the FS_IOC_ENABLE_VERITY ioctl. +// +// Lock order: +// +// filesystem.renameMu +// dentry.dirMu +// fileDescription.mu +// filesystem.verityMu +// dentry.hashMu +// +// Locking dentry.dirMu in multiple dentries requires that parent dentries are +// locked before child dentries, and that filesystem.renameMu is locked to +// stabilize this relationship. package verity import ( + "encoding/json" "fmt" + "math" "strconv" "sync/atomic" @@ -31,6 +45,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -41,32 +56,72 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Name is the default filesystem name. -const Name = "verity" +const ( + // Name is the default filesystem name. + Name = "verity" + + // merklePrefix is the prefix of the Merkle tree files. For example, the Merkle + // tree file for "/foo" is "/.merkle.verity.foo". + merklePrefix = ".merkle.verity." + + // merkleRootPrefix is the prefix of the Merkle tree root file. This + // needs to be different from merklePrefix to avoid name collision. + merkleRootPrefix = ".merkleroot.verity." -// merklePrefix is the prefix of the Merkle tree files. For example, the Merkle -// tree file for "/foo" is "/.merkle.verity.foo". -const merklePrefix = ".merkle.verity." + // merkleOffsetInParentXattr is the extended attribute name specifying the + // offset of the child hash in its parent's Merkle tree. + merkleOffsetInParentXattr = "user.merkle.offset" -// merkleoffsetInParentXattr is the extended attribute name specifying the -// offset of child hash in its parent's Merkle tree. -const merkleOffsetInParentXattr = "user.merkle.offset" + // merkleSizeXattr is the extended attribute name specifying the size of data + // hashed by the corresponding Merkle tree. For a regular file, this is the + // file size. For a directory, this is the size of all its children's hashes. + merkleSizeXattr = "user.merkle.size" -// merkleSizeXattr is the extended attribute name specifying the size of data -// hashed by the corresponding Merkle tree. For a file, it's the size of the -// whole file. For a directory, it's the size of all its children's hashes. -const merkleSizeXattr = "user.merkle.size" + // childrenOffsetXattr is the extended attribute name specifying the + // names of the offset of the serialized children names in the Merkle + // tree file. + childrenOffsetXattr = "user.merkle.childrenOffset" -// sizeOfStringInt32 is the size for a 32 bit integer stored as string in -// extended attributes. The maximum value of a 32 bit integer is 10 digits. -const sizeOfStringInt32 = 10 + // childrenSizeXattr is the extended attribute name specifying the size + // of the serialized children names. + childrenSizeXattr = "user.merkle.childrenSize" -// noCrashOnVerificationFailure indicates whether the sandbox should panic -// whenever verification fails. If true, an error is returned instead of -// panicking. This should only be set for tests. -// TOOD(b/165661693): Decide whether to panic or return error based on this -// flag. -var noCrashOnVerificationFailure bool + // sizeOfStringInt32 is the size for a 32 bit integer stored as string in + // extended attributes. The maximum value of a 32 bit integer has 10 digits. + sizeOfStringInt32 = 10 +) + +var ( + // noCrashOnVerificationFailure indicates whether the sandbox should panic + // whenever verification fails. If true, an error is returned instead of + // panicking. This should only be set for tests. + noCrashOnVerificationFailure bool + + // verityMu synchronizes concurrent operations that enable verity and perform + // verification checks. + verityMu sync.RWMutex +) + +// HashAlgorithm is a type specifying the algorithm used to hash the file +// content. +type HashAlgorithm int + +// Currently supported hashing algorithms include SHA256 and SHA512. +const ( + SHA256 HashAlgorithm = iota + SHA512 +) + +func (alg HashAlgorithm) toLinuxHashAlg() int { + switch alg { + case SHA256: + return linux.FS_VERITY_HASH_ALG_SHA256 + case SHA512: + return linux.FS_VERITY_HASH_ALG_SHA512 + default: + return 0 + } +} // FilesystemType implements vfs.FilesystemType. // @@ -97,6 +152,10 @@ type filesystem struct { // stores the root hash of the whole file system in bytes. rootDentry *dentry + // alg is the algorithms used to hash the files in the verity file + // system. + alg HashAlgorithm + // renameMu synchronizes renaming with non-renaming operations in order // to ensure consistent lock ordering between dentry.dirMu in different // dentries. @@ -125,6 +184,10 @@ type InternalFilesystemOptions struct { // LowerName is the name of the filesystem wrapped by verity fs. LowerName string + // Alg is the algorithms used to hash the files in the verity file + // system. + Alg HashAlgorithm + // RootHash is the root hash of the overall verity file system. RootHash []byte @@ -153,10 +216,10 @@ func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. -func alertIntegrityViolation(err error, msg string) error { +// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +func alertIntegrityViolation(msg string) error { if noCrashOnVerificationFailure { - return err + return syserror.EIO } panic(msg) } @@ -183,6 +246,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs := &filesystem{ creds: creds.Fork(), + alg: iopts.Alg, lowerMount: mnt, allowRuntimeEnable: iopts.AllowRuntimeEnable, } @@ -195,7 +259,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merklePrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -236,7 +300,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // the root Merkle file, or it's never generated. fs.vfsfs.DecRef(ctx) d.DecRef(ctx) - return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") + return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") } d.lowerMerkleVD = lowerMerkleVD @@ -258,14 +322,77 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt d.uid = stat.UID d.gid = stat.GID d.hash = make([]byte, len(iopts.RootHash)) + d.childrenNames = make(map[string]struct{}) if !fs.allowRuntimeEnable { - if err := fs.verifyStat(ctx, d, stat); err != nil { + // Get children names from the underlying file system. + offString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.GetXattrOptions{ + Name: childrenOffsetXattr, + Size: sizeOfStringInt32, + }) + if err == syserror.ENOENT || err == syserror.ENODATA { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) + } + if err != nil { + return nil, nil, err + } + + off, err := strconv.Atoi(offString) + if err != nil { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + } + + sizeString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.GetXattrOptions{ + Name: childrenSizeXattr, + Size: sizeOfStringInt32, + }) + if err == syserror.ENOENT || err == syserror.ENODATA { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) + } + if err != nil { + return nil, nil, err + } + size, err := strconv.Atoi(sizeString) + if err != nil { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + } + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err == syserror.ENOENT { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) + } + if err != nil { + return nil, nil, err + } + + childrenNames := make([]byte, size) + if _, err := lowerMerkleFD.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(off), vfs.ReadOptions{}); err != nil { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err)) + } + + if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { + return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) + } + + if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { return nil, nil, err } } + d.hashMu.Lock() copy(d.hash, iopts.RootHash) + d.hashMu.Unlock() d.vfsd.Init(d) fs.rootDentry = d @@ -289,11 +416,13 @@ type dentry struct { // fs is the owning filesystem. fs is immutable. fs *filesystem - // mode, uid and gid are the file mode, owner, and group of the file in - // the underlying file system. + // mode, uid, gid and size are the file mode, owner, group, and size of + // the file in the underlying file system. They are set when a dentry + // is initialized, and never modified. mode uint32 uid uint32 gid uint32 + size uint32 // parent is the dentry corresponding to this dentry's parent directory. // name is this dentry's name in parent. If this dentry is a filesystem @@ -310,15 +439,24 @@ type dentry struct { dirMu sync.Mutex `state:"nosave"` children map[string]*dentry - // lowerVD is the VirtualDentry in the underlying file system. + // childrenNames stores the name of all children of the dentry. This is + // used by verity to check whether a child is expected. This is only + // populated by enableVerity. childrenNames is also protected by dirMu. + childrenNames map[string]struct{} + + // lowerVD is the VirtualDentry in the underlying file system. It is + // never modified after initialized. lowerVD vfs.VirtualDentry // lowerMerkleVD is the VirtualDentry of the corresponding Merkle tree - // in the underlying file system. + // in the underlying file system. It is never modified after + // initialized. lowerMerkleVD vfs.VirtualDentry - // hash is the calculated hash for the current file or directory. - hash []byte + // hash is the calculated hash for the current file or directory. hash + // is protected by hashMu. + hashMu sync.RWMutex `state:"nosave"` + hash []byte } // newDentry creates a new dentry representing the given verity file. The @@ -331,22 +469,29 @@ func (fs *filesystem) newDentry() *dentry { fs: fs, } d.vfsd.Init(d) + refsvfs2.Register(d) return d } // IncRef implements vfs.DentryImpl.IncRef. func (d *dentry) IncRef() { - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + if d.LogRefs() { + refsvfs2.LogIncRef(d, r) + } } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + if d.LogRefs() { + refsvfs2.LogTryIncRef(d, r+1) + } return true } } @@ -354,15 +499,31 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + r := atomic.AddInt64(&d.refs, -1) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } + if r == 0 { d.fs.renameMu.Lock() d.checkDropLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { + } else if r < 0 { panic("verity.dentry.DecRef() called without holding a reference") } } +func (d *dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } + if r == 0 { + d.checkDropLocked(ctx) + } else if r < 0 { + panic("verity.dentry.decRefLocked() called without holding a reference") + } +} + // checkDropLocked should be called after d's reference count becomes 0 or it // becomes deleted. func (d *dentry) checkDropLocked(ctx context.Context) { @@ -378,7 +539,9 @@ func (d *dentry) checkDropLocked(ctx context.Context) { // destroyLocked destroys the dentry. // -// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. +// Preconditions: +// * d.fs.renameMu must be locked for writing. +// * d.refs == 0. func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: @@ -393,23 +556,36 @@ func (d *dentry) destroyLocked(ctx context.Context) { if d.lowerVD.Ok() { d.lowerVD.DecRef(ctx) } - if d.lowerMerkleVD.Ok() { d.lowerMerkleVD.DecRef(ctx) } - if d.parent != nil { d.parent.dirMu.Lock() if !d.vfsd.IsDead() { delete(d.parent.children, d.name) } d.parent.dirMu.Unlock() - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkDropLocked(ctx) - } else if refs < 0 { - panic("verity.dentry.DecRef() called without holding a reference") - } + d.parent.decRefLocked(ctx) } + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *dentry) RefType() string { + return "verity.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[verity.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -445,9 +621,21 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) // mode, it returns true if the target has been enabled with // ioctl(FS_IOC_ENABLE_VERITY). func (d *dentry) verityEnabled() bool { + d.hashMu.RLock() + defer d.hashMu.RUnlock() return !d.fs.allowRuntimeEnable || len(d.hash) != 0 } +// getLowerAt returns the dentry in the underlying file system, which is +// represented by filename relative to d. +func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) { + return vfsObj.GetDentryAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(filename), + }, &vfs.GetDentryOptions{}) +} + func (d *dentry) readlink(ctx context.Context) (string, error) { return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: d.lowerVD, @@ -489,6 +677,10 @@ type fileDescription struct { // directory that contains the current file/directory. This is only used // if allowRuntimeEnable is set to true. parentMerkleWriter *vfs.FileDescription + + // off is the file offset. off is protected by mu. + mu sync.Mutex `state:"nosave"` + off int64 } // Release implements vfs.FileDescriptionImpl.Release. @@ -510,11 +702,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu if err != nil { return linux.Statx{}, err } + fd.d.dirMu.Lock() if fd.d.verityEnabled() { - if err := fd.d.fs.verifyStat(ctx, fd.d, stat); err != nil { + if err := fd.d.fs.verifyStatAndChildrenLocked(ctx, fd.d, stat); err != nil { return linux.Statx{}, err } } + fd.d.dirMu.Unlock() return stat, nil } @@ -524,28 +718,59 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } -// generateMerkle generates a Merkle tree file for fd. If fd points to a file -// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash -// of the generated Merkle tree and the data size is returned. If fd points to -// a regular file, the data is the content of the file. If fd points to a -// directory, the data is all hahes of its children, written to the Merkle tree -// file. -func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) { - fdReader := vfs.FileReadWriteSeeker{ +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + n := int64(0) + switch whence { + case linux.SEEK_SET: + // use offset as specified + case linux.SEEK_CUR: + n = fd.off + case linux.SEEK_END: + n = int64(fd.d.size) + default: + return 0, syserror.EINVAL + } + if offset > math.MaxInt64-n { + return 0, syserror.EINVAL + } + offset += n + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + +// generateMerkleLocked generates a Merkle tree file for fd. If fd points to a +// file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The +// hash of the generated Merkle tree and the data size is returned. If fd +// points to a regular file, the data is the content of the file. If fd points +// to a directory, the data is all hahes of its children, written to the Merkle +// tree file. +// +// Preconditions: fd.d.fs.verityMu must be locked. +func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) { + fdReader := FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, } - merkleReader := vfs.FileReadWriteSeeker{ + merkleReader := FileReadWriteSeeker{ FD: fd.merkleReader, Ctx: ctx, } - merkleWriter := vfs.FileReadWriteSeeker{ + merkleWriter := FileReadWriteSeeker{ FD: fd.merkleWriter, Ctx: ctx, } params := &merkletree.GenerateParams{ TreeReader: &merkleReader, TreeWriter: &merkleWriter, + Children: fd.d.childrenNames, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), } switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { @@ -596,9 +821,48 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, return hash, uint64(params.Size), err } +// recordChildrenLocked writes the names of fd's children into the +// corresponding Merkle tree file, and saves the offset/size of the map into +// xattrs. +// +// Preconditions: +// * fd.d.fs.verityMu must be locked. +// * fd.d.isDir() == true. +func (fd *fileDescription) recordChildrenLocked(ctx context.Context) error { + // Record the children names in the Merkle tree file. + childrenNames, err := json.Marshal(fd.d.childrenNames) + if err != nil { + return err + } + + stat, err := fd.merkleWriter.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return err + } + + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: childrenOffsetXattr, + Value: strconv.Itoa(int(stat.Size)), + }); err != nil { + return err + } + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: childrenSizeXattr, + Value: strconv.Itoa(len(childrenNames)), + }); err != nil { + return err + } + + if _, err = fd.merkleWriter.Write(ctx, usermem.BytesIOSequence(childrenNames), vfs.WriteOptions{}); err != nil { + return err + } + + return nil +} + // enableVerity enables verity features on fd by generating a Merkle tree file // and stores its hash in its parent directory's Merkle tree. -func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) { +func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { if !fd.d.fs.allowRuntimeEnable { return 0, syserror.EPERM } @@ -611,10 +875,10 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // or directory other than the root, the parent Merkle tree file should // have also been initialized. if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { - return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") + return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } - hash, dataSize, err := fd.generateMerkle(ctx) + hash, dataSize, err := fd.generateMerkleLocked(ctx) if err != nil { return 0, err } @@ -641,6 +905,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui }); err != nil { return 0, err } + + // Add the current child's name to parent's childrenNames. + fd.d.parent.childrenNames[fd.d.name] = struct{}{} } // Record the size of the data being hashed for fd. @@ -650,15 +917,29 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui }); err != nil { return 0, err } - fd.d.hash = append(fd.d.hash, hash...) + + if fd.d.isDir() { + if err := fd.recordChildrenLocked(ctx); err != nil { + return 0, err + } + } + fd.d.hashMu.Lock() + fd.d.hash = hash + fd.d.hashMu.Unlock() return 0, nil } // measureVerity returns the hash of fd, saved in verityDigest. -func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { +func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest usermem.Addr) (uintptr, error) { t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } var metadata linux.DigestMetadata + fd.d.hashMu.RLock() + defer fd.d.hashMu.RUnlock() + // If allowRuntimeEnable is true, an empty fd.d.hash indicates that // verity is not enabled for the file. If allowRuntimeEnable is false, // this is an integrity violation because all files should have verity @@ -667,7 +948,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve if fd.d.fs.allowRuntimeEnable { return 0, syserror.ENODATA } - return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found") + return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found") } // The first part of VerityDigest is the metadata. @@ -692,16 +973,21 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve return 0, err } -func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) { +func (fd *fileDescription) verityFlags(ctx context.Context, flags usermem.Addr) (uintptr, error) { f := int32(0) + fd.d.hashMu.RLock() // All enabled files should store a hash. This flag is not settable via // FS_IOC_SETFLAGS. if len(fd.d.hash) != 0 { f |= linux.FS_VERITY_FL } + fd.d.hashMu.RUnlock() t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } _, err := primitive.CopyInt32Out(t, flags, f) return 0, err } @@ -710,11 +996,11 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := args[1].Uint(); cmd { case linux.FS_IOC_ENABLE_VERITY: - return fd.enableVerity(ctx, uio) + return fd.enableVerity(ctx) case linux.FS_IOC_MEASURE_VERITY: - return fd.measureVerity(ctx, uio, args[2].Pointer()) + return fd.measureVerity(ctx, args[2].Pointer()) case linux.FS_IOC_GETFLAGS: - return fd.verityFlags(ctx, uio, args[2].Pointer()) + return fd.verityFlags(ctx, args[2].Pointer()) default: // TODO(b/169682228): Investigate which ioctl commands should // be allowed. @@ -722,6 +1008,16 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. } } +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Implement Read with PRead by setting offset. + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.mu.Unlock() + return n, err +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // No need to verify if the file is not enabled yet in @@ -742,7 +1038,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { return 0, err @@ -752,39 +1048,54 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } - dataReader := vfs.FileReadWriteSeeker{ + dataReader := FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, } - merkleReader := vfs.FileReadWriteSeeker{ + merkleReader := FileReadWriteSeeker{ FD: fd.merkleReader, Ctx: ctx, } + fd.d.hashMu.RLock() n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + Children: fd.d.childrenNames, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), Expected: fd.d.hash, DataAndTreeInSameFile: false, }) + fd.d.hashMu.RUnlock() if err != nil { - return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } return n, err } +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) @@ -794,3 +1105,45 @@ func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence) } + +// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as +// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. +type FileReadWriteSeeker struct { + FD *vfs.FileDescription + Ctx context.Context + ROpts vfs.ReadOptions + WOpts vfs.WriteOptions +} + +// ReadAt implements io.ReaderAt.ReadAt. +func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts) + return int(n), err +} + +// Read implements io.ReadWriteSeeker.Read. +func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.Read(f.Ctx, dst, f.ROpts) + return int(n), err +} + +// Seek implements io.ReadWriteSeeker.Seek. +func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { + return f.FD.Seek(f.Ctx, offset, int32(whence)) +} + +// WriteAt implements io.WriterAt.WriteAt. +func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts) + return int(n), err +} + +// Write implements io.ReadWriteSeeker.Write. +func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { + buf := usermem.BytesIOSequence(p) + n, err := f.FD.Write(f.Ctx, buf, f.WOpts) + return int(n), err +} diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index e301d35f5..bd948715f 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -18,6 +18,7 @@ import ( "fmt" "io" "math/rand" + "strconv" "testing" "time" @@ -25,27 +26,59 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) -// rootMerkleFilename is the name of the root Merkle tree file. -const rootMerkleFilename = "root.verity" +const ( + // rootMerkleFilename is the name of the root Merkle tree file. + rootMerkleFilename = "root.verity" + // maxDataSize is the maximum data size of a test file. + maxDataSize = 100000 +) + +var hashAlgs = []HashAlgorithm{SHA256, SHA512} -// maxDataSize is the maximum data size written to the file for test. -const maxDataSize = 100000 +func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry { + t.Helper() + d, ok := vd.Dentry().Impl().(*dentry) + if !ok { + t.Fatalf("can't assert %T as a *dentry", vd) + } + return d +} + +// dentryFromFD returns the dentry corresponding to fd. +func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry { + t.Helper() + f, ok := fd.Impl().(*fileDescription) + if !ok { + t.Fatalf("can't assert %T as a *fileDescription", fd) + } + return f.d +} // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. -func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) { +func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { + t.Helper() + k, err := testutil.Boot() + if err != nil { + t.Fatalf("testutil.Boot: %v", err) + } + + ctx := k.SupervisorContext() + rand.Seed(time.Now().UnixNano()) vfsObj := &vfs.VirtualFilesystem{} if err := vfsObj.Init(ctx); err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -61,39 +94,125 @@ func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, v InternalData: InternalFilesystemOptions{ RootMerkleFileName: rootMerkleFilename, LowerName: "tmpfs", + Alg: hashAlg, AllowRuntimeEnable: true, NoCrashOnVerificationFailure: true, }, }, }) if err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err) } root := mntns.Root() root.IncRef() - t.Helper() + + // Use lowerRoot in the task as we modify the lower file system + // directly in many tests. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot) + if err != nil { + t.Fatalf("testutil.CreateTask: %v", err) + } + t.Cleanup(func() { root.DecRef(ctx) mntns.DecRef(ctx) }) - return vfsObj, root, nil + return vfsObj, root, task, nil } -// newFileFD creates a new file in the verity mount, and returns the FD. The FD -// points to a file that has random data generated. -func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { - creds := auth.CredentialsFromContext(ctx) - lowerRoot := root.Dentry().Impl().(*dentry).lowerVD +// openVerityAt opens a verity file. +// +// TODO(chongc): release reference from opening the file when done. +func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { + return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: vd, + Start: vd, + Path: fspath.Parse(path), + }, &vfs.OpenOptions{ + Flags: flags, + Mode: mode, + }) +} - // Create the file in the underlying file system. - lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(filePath), +// openLowerAt opens the file in the underlying file system. +// +// TODO(chongc): release reference from opening the file when done. +func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { + return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(path), + }, &vfs.OpenOptions{ + Flags: flags, + Mode: mode, + }) +} + +// openLowerMerkleAt opens the Merkle file in the underlying file system. +// +// TODO(chongc): release reference from opening the file when done. +func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { + return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, - Mode: linux.ModeRegular | mode, + Flags: flags, + Mode: mode, }) +} + +// unlinkLowerAt deletes the file in the underlying file system. +func (d *dentry) unlinkLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error { + return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(path), + }) +} + +// unlinkLowerMerkleAt deletes the Merkle file in the underlying file system. +func (d *dentry) unlinkLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error { + return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(merklePrefix + path), + }) +} + +// renameLowerAt renames file name to newName in the underlying file system. +func (d *dentry) renameLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error { + return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(name), + }, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(newName), + }, &vfs.RenameOptions{}) +} + +// renameLowerMerkleAt renames Merkle file name to newName in the underlying +// file system. +func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error { + return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(merklePrefix + name), + }, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(merklePrefix + newName), + }, &vfs.RenameOptions{}) +} + +// newFileFD creates a new file in the verity mount, and returns the FD. The FD +// points to a file that has random data generated. +func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { + // Create the file in the underlying file system. + lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) if err != nil { return nil, 0, err } @@ -116,20 +235,24 @@ func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.Virt lowerFD.DecRef(ctx) // Now open the verity file descriptor. - fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filePath), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular | mode, - }) + fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) return fd, dataSize, err } -// corruptRandomBit randomly flips a bit in the file represented by fd. -func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { - // Flip a random bit in the underlying file. +// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD. +func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) { + // Create the file in the underlying file system. + _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) + if err != nil { + return nil, err + } + // Now open the verity file descriptor. + fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) + return fd, err +} + +// flipRandomBit randomly flips a bit in the file represented by fd. +func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { randomPos := int64(rand.Intn(size)) byteToModify := make([]byte, 1) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { @@ -142,207 +265,299 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er return nil } +func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { + t.Helper() + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("enable verity: %v", err) + } +} + // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. func TestOpen(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Ensure that the corresponding Merkle tree file is created. + if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { + t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + } + + // Ensure the root merkle tree file is created. + if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { + t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + } } +} - // Ensure that the corresponding Merkle tree file is created. - lowerRoot := root.Dentry().Impl().(*dentry).lowerVD - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { - t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) - } - - // Ensure the root merkle tree file is created. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + rootMerkleFilename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { - t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) +// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity +// file succeeds after enabling verity for it. +func TestPReadUnmodifiedFileSucceeds(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + enableVerity(ctx, t, fd) + + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.PRead: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } } } -// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file -// succeeds after enabling verity for it. +// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity +// file succeeds after enabling verity for it. func TestReadUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - buf := make([]byte, size) - n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.PRead: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + enableVerity(ctx, t, fd) + + buf := make([]byte, size) + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } } +} - if n != int64(size) { - t.Errorf("fd.PRead got read length %d, want %d", n, size) +// TestReadUnmodifiedEmptyFileSucceeds ensures that read from an untouched empty verity +// file succeeds after enabling verity for it. +func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-empty-file" + fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newEmptyFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + enableVerity(ctx, t, fd) + + var buf []byte + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != 0 { + t.Errorf("fd.Read got read length %d, expected 0", n) + } } } // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Ensure reopening the verity enabled file succeeds. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != nil { - t.Errorf("reopen enabled file failed: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + enableVerity(ctx, t, fd) + + // Ensure reopening the verity enabled file succeeds. + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil { + t.Errorf("reopen enabled file failed: %v", err) + } } } -// TestModifiedFileFails ensures that read from a modified verity file fails. -func TestModifiedFileFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) +// TestOpenNonexistentFile ensures that opening a nonexistent file does not +// trigger verification failure, even if the parent directory is verified. +func TestOpenNonexistentFile(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t, SHA256) if err != nil { t.Fatalf("newVerityRoot: %v", err) } filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerFD that's read/writable. - lowerVD := fd.Impl().(*fileDescription).d.lowerVD + // Enable verity on the file and confirms a normal read succeeds. + enableVerity(ctx, t, fd) - lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerVD, - Start: lowerVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) + // Enable verity on the parent directory. + parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } + enableVerity(ctx, t, parentFD) - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + // Ensure open an unexpected file in the parent directory fails with + // ENOENT rather than verification failure. + if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); err != syserror.ENOENT { + t.Errorf("OpenAt unexpected error: %v", err) } +} - // Confirm that read from the modified file fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.PRead succeeded with modified file") +// TestPReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestPReadModifiedFileFails(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + enableVerity(ctx, t, fd) + + // Open a new lowerFD that's read/writable. + lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := flipRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("flipRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded, expected failure") + } + } +} + +// TestReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestReadModifiedFileFails(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + enableVerity(ctx, t, fd) + + // Open a new lowerFD that's read/writable. + lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := flipRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("flipRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.Read succeeded, expected failure") + } } } // TestModifiedMerkleFails ensures that read from a verity file fails if the // corresponding Merkle tree file is modified. func TestModifiedMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerMerkleFD that's read/writable. - lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD - - lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerMerkleVD, - Start: lowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the Merkle tree file. - stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat: %v", err) - } - merkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } - - // Confirm that read from a file with modified Merkle tree fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - fmt.Println(buf) - t.Fatalf("fd.PRead succeeded with modified Merkle file") + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + enableVerity(ctx, t, fd) + + // Open a new lowerMerkleFD that's read/writable. + lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the Merkle tree file. + stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Errorf("lowerMerkleFD.Stat: %v", err) + } + + if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { + t.Fatalf("flipRandomBit: %v", err) + } + + // Confirm that read from a file with modified Merkle tree fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded with modified Merkle file") + } } } @@ -350,142 +565,239 @@ func TestModifiedMerkleFails(t *testing.T) { // verity enabled directory fails if the hashes related to the target file in // the parent Merkle tree file is modified. func TestModifiedParentMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Enable verity on the parent directory. - parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerMerkleFD that's read/writable. - parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD - - parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: parentLowerMerkleVD, - Start: parentLowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the parent Merkle tree file. - // This parent directory contains only one child, so any random - // modification in the parent Merkle tree should cause verification - // failure when opening the child file. - stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat: %v", err) - } - parentMerkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } - - parentLowerMerkleFD.DecRef(ctx) - - // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err == nil { - t.Errorf("OpenAt file with modified parent Merkle succeeded") + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + enableVerity(ctx, t, fd) + + // Enable verity on the parent directory. + parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + enableVerity(ctx, t, parentFD) + + // Open a new lowerMerkleFD that's read/writable. + parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the parent Merkle tree file. + // This parent directory contains only one child, so any random + // modification in the parent Merkle tree should cause verification + // failure when opening the child file. + sizeString, err := parentLowerMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: childrenOffsetXattr, + Size: sizeOfStringInt32, + }) + if err != nil { + t.Fatalf("parentLowerMerkleFD.GetXattr: %v", err) + } + parentMerkleSize, err := strconv.Atoi(sizeString) + if err != nil { + t.Fatalf("Failed convert size to int: %v", err) + } + if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("flipRandomBit: %v", err) + } + + parentLowerMerkleFD.DecRef(ctx) + + // Ensure reopening the verity enabled file fails. + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err == nil { + t.Errorf("OpenAt file with modified parent Merkle succeeded") + } } } // TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file // succeeds after enabling verity for it. func TestUnmodifiedStatSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms stat succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } - - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { - t.Errorf("fd.Stat: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm that stat succeeds. + enableVerity(ctx, t, fd) + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { + t.Errorf("fd.Stat: %v", err) + } } } // TestModifiedStatFails checks that getting stat for a file with modified stat // should fail. func TestModifiedStatFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + enableVerity(ctx, t, fd) + + lowerFD := fd.Impl().(*fileDescription).lowerFD + // Change the stat of the underlying file, and check that stat fails. + if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: uint32(linux.STATX_MODE), + Mode: 0777, + }, + }); err != nil { + t.Fatalf("lowerFD.SetStat: %v", err) + } - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { + t.Errorf("fd.Stat succeeded when it should fail") + } } +} - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) +// TestOpenDeletedFileFails ensures that opening a deleted verity enabled file +// and/or the corresponding Merkle tree file fails with the verity error. +func TestOpenDeletedFileFails(t *testing.T) { + testCases := []struct { + name string + // The original file is removed if changeFile is true. + changeFile bool + // The Merkle tree file is removed if changeMerkleFile is true. + changeMerkleFile bool + }{ + { + name: "FileOnly", + changeFile: true, + changeMerkleFile: false, + }, + { + name: "MerkleOnly", + changeFile: false, + changeMerkleFile: true, + }, + { + name: "FileAndMerkle", + changeFile: true, + changeMerkleFile: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t, SHA256) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + enableVerity(ctx, t, fd) + + if tc.changeFile { + if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + + // Ensure reopening the verity enabled file fails. + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } + }) } +} - lowerFD := fd.Impl().(*fileDescription).lowerFD - // Change the stat of the underlying file, and check that stat fails. - if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: uint32(linux.STATX_MODE), - Mode: 0777, +// TestOpenRenamedFileFails ensures that opening a renamed verity enabled file +// and/or the corresponding Merkle tree file fails with the verity error. +func TestOpenRenamedFileFails(t *testing.T) { + testCases := []struct { + name string + // The original file is renamed if changeFile is true. + changeFile bool + // The Merkle tree file is renamed if changeMerkleFile is true. + changeMerkleFile bool + }{ + { + name: "FileOnly", + changeFile: true, + changeMerkleFile: false, + }, + { + name: "MerkleOnly", + changeFile: false, + changeMerkleFile: true, + }, + { + name: "FileAndMerkle", + changeFile: true, + changeMerkleFile: true, }, - }); err != nil { - t.Fatalf("lowerFD.SetStat: %v", err) } - - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { - t.Errorf("fd.Stat succeeded when it should fail") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t, SHA256) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + enableVerity(ctx, t, fd) + + newFilename := "renamed-test-file" + if tc.changeFile { + if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil { + t.Fatalf("RenameAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + + // Ensure reopening the verity enabled file fails. + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } + }) } } diff --git a/pkg/sentry/fsmetric/BUILD b/pkg/sentry/fsmetric/BUILD new file mode 100644 index 000000000..4e86fbdd8 --- /dev/null +++ b/pkg/sentry/fsmetric/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "fsmetric", + srcs = ["fsmetric.go"], + visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/metric"], +) diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go new file mode 100644 index 000000000..7e535b527 --- /dev/null +++ b/pkg/sentry/fsmetric/fsmetric.go @@ -0,0 +1,83 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fsmetric defines filesystem metrics that are used by both VFS1 and +// VFS2. +// +// TODO(gvisor.dev/issue/1624): Once VFS1 is deleted, inline these metrics into +// VFS2. +package fsmetric + +import ( + "time" + + "gvisor.dev/gvisor/pkg/metric" +) + +// RecordWaitTime enables the ReadWait, GoferReadWait9P, GoferReadWaitHost, and +// TmpfsReadWait metrics. Enabling this comes at a CPU cost due to performing +// three clock reads per read call. +// +// Note that this is only performed in the direct read path, and may not be +// consistently applied for other forms of reads, such as splice. +var RecordWaitTime = false + +// Metrics that apply to all filesystems. +var ( + Opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.") + Reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.") + ReadWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.") +) + +// Metrics that only apply to fs/gofer and fsimpl/gofer. +var ( + GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.") + GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.") + GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.") + GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") + GoferReadWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.") + GoferReadsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.") + GoferReadWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.") +) + +// Metrics that only apply to fs/tmpfs and fsimpl/tmpfs. +var ( + TmpfsOpensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.") + TmpfsOpensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.") + TmpfsReads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.") + TmpfsReadWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.") +) + +// StartReadWait indicates the beginning of a file read. +func StartReadWait() time.Time { + if !RecordWaitTime { + return time.Time{} + } + return time.Now() +} + +// FinishReadWait indicates the end of a file read whose time is accounted by +// m. start must be the value returned by the corresponding call to +// StartReadWait. +// +// FinishReadWait is marked nosplit for performance since it's often called +// from defer statements, which prevents it from being inlined +// (https://github.com/golang/go/issues/38471). +//go:nosplit +func FinishReadWait(m *metric.Uint64Metric, start time.Time) { + if !RecordWaitTime { + return + } + m.IncrementBy(uint64(time.Since(start).Nanoseconds())) +} diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD index 364a78306..db3b0d0a0 100644 --- a/pkg/sentry/hostfd/BUILD +++ b/pkg/sentry/hostfd/BUILD @@ -6,10 +6,12 @@ go_library( name = "hostfd", srcs = [ "hostfd.go", + "hostfd_linux.go", "hostfd_unsafe.go", ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/log", "//pkg/safemem", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/goid/empty_test.go b/pkg/sentry/hostfd/hostfd_linux.go index c0a4b17ab..1cabc848f 100644 --- a/pkg/goid/empty_test.go +++ b/pkg/sentry/hostfd/hostfd_linux.go @@ -12,11 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !race +package hostfd -package goid - -import "testing" - -// TestNothing exists to make the build system happy. -func TestNothing(t *testing.T) {} +// maxIov is the maximum permitted size of a struct iovec array. +const maxIov = 1024 // UIO_MAXIOV diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go index cd4dc67fb..694371b1c 100644 --- a/pkg/sentry/hostfd/hostfd_unsafe.go +++ b/pkg/sentry/hostfd/hostfd_unsafe.go @@ -20,6 +20,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" ) @@ -44,6 +45,10 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6 } } else { iovs := safemem.IovecsFromBlockSeq(dsts) + if len(iovs) > maxIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) + iovs = iovs[:maxIov] + } n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } if e != 0 { @@ -76,6 +81,10 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint } } else { iovs := safemem.IovecsFromBlockSeq(srcs) + if len(iovs) > maxIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) + iovs = iovs[:maxIov] + } n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } if e != 0 { diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index fbe6d6aa6..f31277d30 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -32,9 +32,13 @@ type Stack interface { InterfaceAddrs() map[int32][]InterfaceAddr // AddInterfaceAddr adds an address to the network interface identified by - // index. + // idx. AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // RemoveInterfaceAddr removes an address from the network interface + // identified by idx. + RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 1779cc6f3..9ebeba8a3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -15,6 +15,9 @@ package inet import ( + "bytes" + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } +// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { + interfaceAddrs, ok := s.InterfaceAddrsMap[idx] + if !ok { + return fmt.Errorf("unknown idx: %d", idx) + } + + var filteredAddrs []InterfaceAddr + for _, interfaceAddr := range interfaceAddrs { + if !bytes.Equal(interfaceAddr.Addr, addr.Addr) { + filteredAddrs = append(filteredAddrs, addr) + } + } + s.InterfaceAddrsMap[idx] = filteredAddrs + + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index c0de72eef..0ee60569c 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -79,7 +79,7 @@ go_template_instance( out = "fd_table_refs.go", package = "kernel", prefix = "FDTable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FDTable", }, @@ -90,7 +90,7 @@ go_template_instance( out = "fs_context_refs.go", package = "kernel", prefix = "FSContext", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FSContext", }, @@ -101,7 +101,7 @@ go_template_instance( out = "ipc_namespace_refs.go", package = "kernel", prefix = "IPCNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "IPCNamespace", }, @@ -112,7 +112,7 @@ go_template_instance( out = "process_group_refs.go", package = "kernel", prefix = "ProcessGroup", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ProcessGroup", }, @@ -123,7 +123,7 @@ go_template_instance( out = "session_refs.go", package = "kernel", prefix = "Session", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Session", }, @@ -184,6 +184,7 @@ go_library( "task_exit.go", "task_futex.go", "task_identity.go", + "task_image.go", "task_list.go", "task_log.go", "task_net.go", @@ -224,12 +225,13 @@ go_library( "//pkg/cpuid", "//pkg/eventchannel", "//pkg/fspath", + "//pkg/goid", "//pkg/log", "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/metric", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/secio", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 1b9721534..0ddbe5ff6 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -19,7 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs_vfs2" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" ) @@ -27,7 +27,7 @@ import ( // +stateify savable type abstractEndpoint struct { ep transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter name string ns *AbstractSocketNamespace } @@ -57,7 +57,7 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace { // its backing socket. type boundEndpoint struct { transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter } // Release implements transport.BoundEndpoint.Release. @@ -89,7 +89,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp // // When the last reference managed by socket is dropped, ep may be removed from the // namespace. -func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error { +func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refsvfs2.RefCounter) error { a.mu.Lock() defer a.mu.Unlock() @@ -109,7 +109,7 @@ func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep tran // Remove removes the specified socket at name from the abstract socket // namespace, if it has not yet been replaced. -func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) { +func (a *AbstractSocketNamespace) Remove(name string, socket refsvfs2.RefCounter) { a.mu.Lock() defer a.mu.Unlock() diff --git a/pkg/sentry/kernel/aio.go b/pkg/sentry/kernel/aio.go index 0ac78c0b8..ec36d1a49 100644 --- a/pkg/sentry/kernel/aio.go +++ b/pkg/sentry/kernel/aio.go @@ -15,10 +15,7 @@ package kernel import ( - "time" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" ) // AIOCallback is an function that does asynchronous I/O on behalf of a task. @@ -26,7 +23,7 @@ type AIOCallback func(context.Context) // QueueAIO queues an AIOCallback which will be run asynchronously. func (t *Task) QueueAIO(cb AIOCallback) { - ctx := taskAsyncContext{t: t} + ctx := t.AsyncContext() wg := &t.TaskSet().aioGoroutines wg.Add(1) go func() { @@ -34,48 +31,3 @@ func (t *Task) QueueAIO(cb AIOCallback) { wg.Done() }() } - -type taskAsyncContext struct { - context.NoopSleeper - t *Task -} - -// Debugf implements log.Logger.Debugf. -func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) { - ctx.t.Debugf(format, v...) -} - -// Infof implements log.Logger.Infof. -func (ctx taskAsyncContext) Infof(format string, v ...interface{}) { - ctx.t.Infof(format, v...) -} - -// Warningf implements log.Logger.Warningf. -func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) { - ctx.t.Warningf(format, v...) -} - -// IsLogging implements log.Logger.IsLogging. -func (ctx taskAsyncContext) IsLogging(level log.Level) bool { - return ctx.t.IsLogging(level) -} - -// Deadline implements context.Context.Deadline. -func (ctx taskAsyncContext) Deadline() (time.Time, bool) { - return ctx.t.Deadline() -} - -// Done implements context.Context.Done. -func (ctx taskAsyncContext) Done() <-chan struct{} { - return ctx.t.Done() -} - -// Err implements context.Context.Err. -func (ctx taskAsyncContext) Err() error { - return ctx.t.Err() -} - -// Value implements context.Context.Value. -func (ctx taskAsyncContext) Value(key interface{}) interface{} { - return ctx.t.Value(key) -} diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index bb94769c4..a8596410f 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -15,8 +15,6 @@ package kernel import ( - "time" - "gvisor.dev/gvisor/pkg/context" ) @@ -98,18 +96,3 @@ func TaskFromContext(ctx context.Context) *Task { } return nil } - -// Deadline implements context.Context.Deadline. -func (*Task) Deadline() (time.Time, bool) { - return time.Time{}, false -} - -// Done implements context.Context.Done. -func (*Task) Done() <-chan struct{} { - return nil -} - -// Err implements context.Context.Err. -func (*Task) Err() error { - return nil -} diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 15519f0df..61aeca044 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -273,7 +273,7 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent { // // Callback is called when one of the files we're polling becomes ready. It // moves said file to the readyList if it's currently in the waiting list. -func (p *pollEntry) Callback(*waiter.Entry) { +func (p *pollEntry) Callback(*waiter.Entry, waiter.EventMask) { e := p.epoll e.listsMu.Lock() @@ -306,9 +306,8 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) { f.EventRegister(&entry.waiter, entry.mask) // Check if the file happens to already be in a ready state. - ready := f.Readiness(entry.mask) & entry.mask - if ready != 0 { - entry.Callback(&entry.waiter) + if ready := f.Readiness(entry.mask) & entry.mask; ready != 0 { + entry.Callback(&entry.waiter, ready) } } diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 2b3955598..f855f038b 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -8,11 +8,13 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/sentry/arch", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/syserror", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index 153d2cd9b..b66d61c6f 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -17,22 +17,45 @@ package fasync import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) -// New creates a new fs.FileAsync. -func New() fs.FileAsync { - return &FileAsync{} +// Table to convert waiter event masks into si_band siginfo codes. +// Taken from fs/fcntl.c:band_table. +var bandTable = map[waiter.EventMask]int64{ + // POLL_IN + waiter.EventIn: linux.EPOLLIN | linux.EPOLLRDNORM, + // POLL_OUT + waiter.EventOut: linux.EPOLLOUT | linux.EPOLLWRNORM | linux.EPOLLWRBAND, + // POLL_ERR + waiter.EventErr: linux.EPOLLERR, + // POLL_PRI + waiter.EventPri: linux.EPOLLPRI | linux.EPOLLRDBAND, + // POLL_HUP + waiter.EventHUp: linux.EPOLLHUP | linux.EPOLLERR, } -// NewVFS2 creates a new vfs.FileAsync. -func NewVFS2() vfs.FileAsync { - return &FileAsync{} +// New returns a function that creates a new fs.FileAsync with the given file +// descriptor. +func New(fd int) func() fs.FileAsync { + return func() fs.FileAsync { + return &FileAsync{fd: fd} + } +} + +// NewVFS2 returns a function that creates a new vfs.FileAsync with the given +// file descriptor. +func NewVFS2(fd int) func() vfs.FileAsync { + return func() vfs.FileAsync { + return &FileAsync{fd: fd} + } } // FileAsync sends signals when the registered file is ready for IO. @@ -42,6 +65,12 @@ type FileAsync struct { // e is immutable after first use (which is protected by mu below). e waiter.Entry + // fd is the file descriptor to notify about. + // It is immutable, set at allocation time. This matches Linux semantics in + // fs/fcntl.c:fasync_helper. + // The fd value is passed to the signal recipient in siginfo.si_fd. + fd int + // regMu protects registeration and unregistration actions on e. // // regMu must be held while registration decisions are being made @@ -56,6 +85,10 @@ type FileAsync struct { mu sync.Mutex `state:"nosave"` requester *auth.Credentials registered bool + // signal is the signal to deliver upon I/O being available. + // The default value ("zero signal") means the default SIGIO signal will be + // delivered. + signal linux.Signal // Only one of the following is allowed to be non-nil. recipientPG *kernel.ProcessGroup @@ -64,10 +97,10 @@ type FileAsync struct { } // Callback sends a signal. -func (a *FileAsync) Callback(e *waiter.Entry) { +func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) { a.mu.Lock() + defer a.mu.Unlock() if !a.registered { - a.mu.Unlock() return } t := a.recipientT @@ -80,19 +113,34 @@ func (a *FileAsync) Callback(e *waiter.Entry) { } if t == nil { // No recipient has been registered. - a.mu.Unlock() return } c := t.Credentials() // Logic from sigio_perm in fs/fcntl.c. - if a.requester.EffectiveKUID == 0 || + permCheck := (a.requester.EffectiveKUID == 0 || a.requester.EffectiveKUID == c.SavedKUID || a.requester.EffectiveKUID == c.RealKUID || a.requester.RealKUID == c.SavedKUID || - a.requester.RealKUID == c.RealKUID { - t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO)) + a.requester.RealKUID == c.RealKUID) + if !permCheck { + return } - a.mu.Unlock() + signalInfo := &arch.SignalInfo{ + Signo: int32(linux.SIGIO), + Code: arch.SignalInfoKernel, + } + if a.signal != 0 { + signalInfo.Signo = int32(a.signal) + signalInfo.SetFD(uint32(a.fd)) + var band int64 + for m, bandCode := range bandTable { + if m&mask != 0 { + band |= bandCode + } + } + signalInfo.SetBand(band) + } + t.SendSignal(signalInfo) } // Register sets the file which will be monitored for IO events. @@ -186,3 +234,25 @@ func (a *FileAsync) ClearOwner() { a.recipientTG = nil a.recipientPG = nil } + +// Signal returns which signal will be sent to the signal recipient. +// A value of zero means the signal to deliver wasn't customized, which means +// the default signal (SIGIO) will be delivered. +func (a *FileAsync) Signal() linux.Signal { + a.mu.Lock() + defer a.mu.Unlock() + return a.signal +} + +// SetSignal overrides which signal to send when I/O is available. +// The default behavior can be reset by specifying signal zero, which means +// to send SIGIO. +func (a *FileAsync) SetSignal(signal linux.Signal) error { + if signal != 0 && !signal.IsValid() { + return syserror.EINVAL + } + a.mu.Lock() + defer a.mu.Unlock() + a.signal = signal + return nil +} diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 0ec7344cd..7aba31587 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -110,7 +110,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() - f.init() // Initialize table. + f.initNoLeakCheck() // Initialize table. f.used = 0 for fd, d := range m { if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { @@ -240,6 +240,10 @@ func (f *FDTable) String() string { case fileVFS2 != nil: vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem() + vd := fileVFS2.VirtualDentry() + if vd.Dentry() == nil { + panic(fmt.Sprintf("fd %d (type %T) has nil dentry: %#v", fd, fileVFS2.Impl(), fileVFS2)) + } name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) if err != nil { fmt.Fprintf(&buf, "<err: %v>\n", err) diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index da79e6627..f17f9c59c 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -31,14 +31,21 @@ type descriptorTable struct { slice unsafe.Pointer `state:".(map[int32]*descriptor)"` } -// init initializes the table. +// initNoLeakCheck initializes the table without enabling leak checking. // -// TODO(gvisor.dev/1486): Enable leak check for FDTable. -func (f *FDTable) init() { +// This is used when loading an FDTable after S/R, during which the ref count +// object itself will enable leak checking if necessary. +func (f *FDTable) initNoLeakCheck() { var slice []unsafe.Pointer // Empty slice. atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) } +// init initializes the table with leak checking. +func (f *FDTable) init() { + f.initNoLeakCheck() + f.InitRefs() +} + // get gets a file entry. // // The boolean indicates whether this was in range. @@ -114,18 +121,21 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 panic("VFS1 and VFS2 files set") } - slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) + slicePtr := (*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) // Grow the table as required. - if last := int32(len(slice)); fd >= last { + if last := int32(len(*slicePtr)); fd >= last { end := fd + 1 if end < 2*last { end = 2 * last } - slice = append(slice, make([]unsafe.Pointer, end-last)...) - atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) + newSlice := append(*slicePtr, make([]unsafe.Pointer, end-last)...) + slicePtr = &newSlice + atomic.StorePointer(&f.slice, unsafe.Pointer(slicePtr)) } + slice := *slicePtr + var desc *descriptor if file != nil || fileVFS2 != nil { desc = &descriptor{ diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index d46d1e1c1..dfde4deee 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -63,7 +63,7 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext { cwd: cwd, umask: umask, } - f.EnableLeakCheck() + f.InitRefs() return &f } @@ -76,7 +76,7 @@ func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext { cwdVFS2: cwd, umask: umask, } - f.EnableLeakCheck() + f.InitRefs() return &f } @@ -130,13 +130,15 @@ func (f *FSContext) Fork() *FSContext { f.root.IncRef() } - return &FSContext{ + ctx := &FSContext{ cwd: f.cwd, root: f.root, cwdVFS2: f.cwdVFS2, rootVFS2: f.rootVFS2, umask: f.umask, } + ctx.InitRefs() + return ctx } // WorkingDirectory returns the current working directory. @@ -147,19 +149,23 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() - f.cwd.IncRef() + if f.cwd != nil { + f.cwd.IncRef() + } return f.cwd } // WorkingDirectoryVFS2 returns the current working directory. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.cwdVFS2.IncRef() + if f.cwdVFS2.Ok() { + f.cwdVFS2.IncRef() + } return f.cwdVFS2 } @@ -218,13 +224,15 @@ func (f *FSContext) RootDirectory() *fs.Dirent { // RootDirectoryVFS2 returns the current filesystem root. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.rootVFS2.IncRef() + if f.rootVFS2.Ok() { + f.rootVFS2.IncRef() + } return f.rootVFS2 } diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go index 3f34ee0db..9545bb5ef 100644 --- a/pkg/sentry/kernel/ipc_namespace.go +++ b/pkg/sentry/kernel/ipc_namespace.go @@ -41,7 +41,7 @@ func NewIPCNamespace(userNS *auth.UserNamespace) *IPCNamespace { semaphores: semaphore.NewRegistry(userNS), shms: shm.NewRegistry(userNS), } - ns.EnableLeakCheck() + ns.InitRefs() return ns } @@ -55,7 +55,7 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry { return i.shms } -// DecRef implements refs_vfs2.RefCounter.DecRef. +// DecRef implements refsvfs2.RefCounter.DecRef. func (i *IPCNamespace) DecRef(ctx context.Context) { i.IPCNamespaceRefs.DecRef(func() { i.shms.Release(ctx) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0eb2bf7bd..b8627a54f 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -214,9 +214,11 @@ type Kernel struct { // netlinkPorts manages allocation of netlink socket port IDs. netlinkPorts *port.Manager - // saveErr is the error causing the sandbox to exit during save, if - // any. It is protected by extMu. - saveErr error `state:"nosave"` + // saveStatus is nil if the sandbox has not been saved, errSaved or + // errAutoSaved if it has been saved successfully, or the error causing the + // sandbox to exit during save. + // It is protected by extMu. + saveStatus error `state:"nosave"` // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` @@ -430,9 +432,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { // SaveTo saves the state of k to w. // // Preconditions: The kernel must be paused throughout the call to SaveTo. -func (k *Kernel) SaveTo(w wire.Writer) error { +func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error { saveStart := time.Now() - ctx := k.SupervisorContext() // Do not allow other Kernel methods to affect it while it's being saved. k.extMu.Lock() @@ -446,38 +447,55 @@ func (k *Kernel) SaveTo(w wire.Writer) error { k.mf.StartEvictions() k.mf.WaitForEvictions() - // Flush write operations on open files so data reaches backing storage. - // This must come after MemoryFile eviction since eviction may cause file - // writes. - if err := k.tasks.flushWritesToFiles(ctx); err != nil { - return err - } + if VFS2Enabled { + // Discard unsavable mappings, such as those for host file descriptors. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } + + // Prepare filesystems for saving. This must be done after + // invalidateUnsavableMappings(), since dropping memory mappings may + // affect filesystem state (e.g. page cache reference counts). + if err := k.vfs.PrepareSave(ctx); err != nil { + return err + } + } else { + // Flush cached file writes to backing storage. This must come after + // MemoryFile eviction since eviction may cause file writes. + if err := k.flushWritesToFiles(ctx); err != nil { + return err + } - // Remove all epoll waiter objects from underlying wait queues. - // NOTE: for programs to resume execution in future snapshot scenarios, - // we will need to re-establish these waiter objects after saving. - k.tasks.unregisterEpollWaiters(ctx) + // Remove all epoll waiter objects from underlying wait queues. + // NOTE: for programs to resume execution in future snapshot scenarios, + // we will need to re-establish these waiter objects after saving. + k.tasks.unregisterEpollWaiters(ctx) - // Clear the dirent cache before saving because Dirents must be Loaded in a - // particular order (parents before children), and Loading dirents from a cache - // breaks that order. - if err := k.flushMountSourceRefs(ctx); err != nil { - return err - } + // Clear the dirent cache before saving because Dirents must be Loaded in a + // particular order (parents before children), and Loading dirents from a cache + // breaks that order. + if err := k.flushMountSourceRefs(ctx); err != nil { + return err + } - // Ensure that all inode and mount release operations have completed. - fs.AsyncBarrier() + // Ensure that all inode and mount release operations have completed. + fs.AsyncBarrier() - // Once all fs work has completed (flushed references have all been released), - // reset mount mappings. This allows individual mounts to save how inodes map - // to filesystem resources. Without this, fs.Inodes cannot be restored. - fs.SaveInodeMappings() + // Once all fs work has completed (flushed references have all been released), + // reset mount mappings. This allows individual mounts to save how inodes map + // to filesystem resources. Without this, fs.Inodes cannot be restored. + fs.SaveInodeMappings() - // Discard unsavable mappings, such as those for host file descriptors. - // This must be done after waiting for "asynchronous fs work", which - // includes async I/O that may touch application memory. - if err := k.invalidateUnsavableMappings(ctx); err != nil { - return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + // Discard unsavable mappings, such as those for host file descriptors. + // This must be done after waiting for "asynchronous fs work", which + // includes async I/O that may touch application memory. + // + // TODO(gvisor.dev/issue/1624): This rationale is believed to be + // obsolete since AIO callbacks are now waited-for by Kernel.Pause(), + // but this order is conservatively retained for VFS1. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } } // Save the CPUID FeatureSet before the rest of the kernel so we can @@ -486,14 +504,14 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil { + if _, err := state.Save(ctx, w, k.FeatureSet()); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) // Save the kernel state. kernelStart := time.Now() - stats, err := state.Save(k.SupervisorContext(), w, k) + stats, err := state.Save(ctx, w, k) if err != nil { return err } @@ -502,7 +520,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { + if err := k.mf.SaveTo(ctx, w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -514,11 +532,9 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. +// +// Preconditions: !VFS2Enabled. func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { - if VFS2Enabled { - return nil // Not relevant. - } - // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() @@ -561,13 +577,9 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi return err } -func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return nil - } - - return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { +// Preconditions: !VFS2Enabled. +func (k *Kernel) flushWritesToFiles(ctx context.Context) error { + return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil } @@ -589,37 +601,8 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { }) } -// Preconditions: The kernel must be paused. -func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { - invalidated := make(map[*mm.MemoryManager]struct{}) - k.tasks.mu.RLock() - defer k.tasks.mu.RUnlock() - for t := range k.tasks.Root.tids { - // We can skip locking Task.mu here since the kernel is paused. - if mm := t.tc.MemoryManager; mm != nil { - if _, ok := invalidated[mm]; !ok { - if err := mm.InvalidateUnsavable(ctx); err != nil { - return err - } - invalidated[mm] = struct{}{} - } - } - // I really wish we just had a sync.Map of all MMs... - if r, ok := t.runState.(*runSyscallAfterExecStop); ok { - if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { - return err - } - } - } - return nil -} - +// Preconditions: !VFS2Enabled. func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return - } - ts.mu.RLock() defer ts.mu.RUnlock() @@ -644,8 +627,33 @@ func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { } } +// Preconditions: The kernel must be paused. +func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { + invalidated := make(map[*mm.MemoryManager]struct{}) + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + for t := range k.tasks.Root.tids { + // We can skip locking Task.mu here since the kernel is paused. + if mm := t.image.MemoryManager; mm != nil { + if _, ok := invalidated[mm]; !ok { + if err := mm.InvalidateUnsavable(ctx); err != nil { + return err + } + invalidated[mm] = struct{}{} + } + } + // I really wish we just had a sync.Map of all MMs... + if r, ok := t.runState.(*runSyscallAfterExecStop); ok { + if err := r.image.MemoryManager.InvalidateUnsavable(ctx); err != nil { + return err + } + } + } + return nil +} + // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error { +func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -656,7 +664,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil { + if _, err := state.Load(ctx, r, &features); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -671,7 +679,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the kernel state. kernelStart := time.Now() - stats, err := state.Load(k.SupervisorContext(), r, k) + stats, err := state.Load(ctx, r, k) if err != nil { return err } @@ -684,7 +692,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { + if err := k.mf.LoadFrom(ctx, r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -696,11 +704,17 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock net.Resume() } - // Ensure that all pending asynchronous work is complete: - // - namedpipe opening - // - inode file opening - if err := fs.AsyncErrorBarrier(); err != nil { - return err + if VFS2Enabled { + if err := k.vfs.CompleteRestore(ctx, vfsOpts); err != nil { + return err + } + } else { + // Ensure that all pending asynchronous work is complete: + // - namedpipe opening + // - inode file opening + if err := fs.AsyncErrorBarrier(); err != nil { + return err + } } tcpip.AsyncLoading.Wait() @@ -1005,7 +1019,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, Features: k.featureSet, } - tc, se := k.LoadTaskImage(ctx, loadArgs) + image, se := k.LoadTaskImage(ctx, loadArgs) if se != nil { return nil, 0, errors.New(se.String()) } @@ -1018,7 +1032,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, config := &TaskConfig{ Kernel: k, ThreadGroup: tg, - TaskContext: tc, + TaskImage: image, FSContext: fsContext, FDTable: args.FDTable, Credentials: args.Credentials, @@ -1034,7 +1048,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, if err != nil { return nil, 0, err } - t.traceExecEvent(tc) // Simulate exec for tracing. + t.traceExecEvent(image) // Simulate exec for tracing. // Success. cu.Release() @@ -1347,6 +1361,13 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error { // not have meaningful trace data. Rebuilding here ensures that we can do so // after tracing has been enabled. func (k *Kernel) RebuildTraceContexts() { + // We need to pause all task goroutines because Task.rebuildTraceContext() + // replaces Task.traceContext and Task.traceTask, which are + // task-goroutine-exclusive (i.e. the task goroutine assumes that it can + // access them without synchronization) for performance. + k.Pause() + defer k.Unpause() + k.extMu.Lock() defer k.extMu.Unlock() k.tasks.mu.RLock() @@ -1462,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager { return k.netlinkPorts } -// SaveError returns the sandbox error that caused the kernel to exit during -// save. -func (k *Kernel) SaveError() error { +var ( + errSaved = errors.New("sandbox has been successfully saved") + errAutoSaved = errors.New("sandbox has been successfully auto-saved") +) + +// SaveStatus returns the sandbox save status. If it was saved successfully, +// autosaved indicates whether save was triggered by autosave. If it was not +// saved successfully, err indicates the sandbox error that caused the kernel to +// exit during save. +func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) { k.extMu.Lock() defer k.extMu.Unlock() - return k.saveErr + switch k.saveStatus { + case nil: + return false, false, nil + case errSaved: + return true, false, nil + case errAutoSaved: + return true, true, nil + default: + return false, false, k.saveStatus + } +} + +// SetSaveSuccess sets the flag indicating that save completed successfully, if +// no status was already set. +func (k *Kernel) SetSaveSuccess(autosave bool) { + k.extMu.Lock() + defer k.extMu.Unlock() + if k.saveStatus == nil { + if autosave { + k.saveStatus = errAutoSaved + } else { + k.saveStatus = errSaved + } + } } // SetSaveError sets the sandbox error that caused the kernel to exit during @@ -1475,8 +1526,8 @@ func (k *Kernel) SaveError() error { func (k *Kernel) SetSaveError(err error) { k.extMu.Lock() defer k.extMu.Unlock() - if k.saveErr == nil { - k.saveErr = err + if k.saveStatus == nil { + k.saveStatus = err } } diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index ce0db5583..d6fb0fdb8 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) type sleeper struct { @@ -66,7 +65,8 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag d := fs.NewDirent(ctx, inode, "pipe") file, err := n.GetFile(ctx, d, flags) if err != nil { - t.Fatalf("open with flags %+v failed: %v", flags, err) + t.Errorf("open with flags %+v failed: %v", flags, err) + return nil, err } if doneChan != nil { doneChan <- struct{}{} @@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs. } func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(true, DefaultPipeSize, usermem.PageSize) + return NewPipe(true, DefaultPipeSize) } func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(false, DefaultPipeSize, usermem.PageSize) + return NewPipe(false, DefaultPipeSize) } // assertRecvBlocks ensures that a recv attempt on c blocks for at least diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 67beb0ad6..b989e14c7 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -26,18 +26,27 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) const ( // MinimumPipeSize is a hard limit of the minimum size of a pipe. - MinimumPipeSize = 64 << 10 + // It corresponds to fs/pipe.c:pipe_min_size. + MinimumPipeSize = usermem.PageSize + + // MaximumPipeSize is a hard limit on the maximum size of a pipe. + // It corresponds to fs/pipe.c:pipe_max_size. + MaximumPipeSize = 1048576 // DefaultPipeSize is the system-wide default size of a pipe in bytes. - DefaultPipeSize = MinimumPipeSize + // It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS. + DefaultPipeSize = 16 * usermem.PageSize - // MaximumPipeSize is a hard limit on the maximum size of a pipe. - MaximumPipeSize = 8 << 20 + // atomicIOBytes is the maximum number of bytes that the pipe will + // guarantee atomic reads or writes atomically. + // It corresponds to limits.h:PIPE_BUF. + atomicIOBytes = 4096 ) // Pipe is an encapsulation of a platform-independent pipe. @@ -53,12 +62,6 @@ type Pipe struct { // This value is immutable. isNamed bool - // atomicIOBytes is the maximum number of bytes that the pipe will - // guarantee atomic reads or writes atomically. - // - // This value is immutable. - atomicIOBytes int64 - // The number of active readers for this pipe. // // Access atomically. @@ -94,47 +97,34 @@ type Pipe struct { // NewPipe initializes and returns a pipe. // -// N.B. The size and atomicIOBytes will be bounded. -func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { +// N.B. The size will be bounded. +func NewPipe(isNamed bool, sizeBytes int64) *Pipe { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } if sizeBytes > MaximumPipeSize { sizeBytes = MaximumPipeSize } - if atomicIOBytes <= 0 { - atomicIOBytes = 1 - } - if atomicIOBytes > sizeBytes { - atomicIOBytes = sizeBytes - } var p Pipe - initPipe(&p, isNamed, sizeBytes, atomicIOBytes) + initPipe(&p, isNamed, sizeBytes) return &p } -func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) { +func initPipe(pipe *Pipe, isNamed bool, sizeBytes int64) { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } if sizeBytes > MaximumPipeSize { sizeBytes = MaximumPipeSize } - if atomicIOBytes <= 0 { - atomicIOBytes = 1 - } - if atomicIOBytes > sizeBytes { - atomicIOBytes = sizeBytes - } pipe.isNamed = isNamed pipe.max = sizeBytes - pipe.atomicIOBytes = atomicIOBytes } // NewConnectedPipe initializes a pipe and returns a pair of objects // representing the read and write ends of the pipe. -func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { - p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes) +func NewConnectedPipe(ctx context.Context, sizeBytes int64) (*fs.File, *fs.File) { + p := NewPipe(false /* isNamed */, sizeBytes) // Build an fs.Dirent for the pipe which will be shared by both // returned files. @@ -264,7 +254,7 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { wanted := ops.left() avail := p.max - p.view.Size() if wanted > avail { - if wanted <= p.atomicIOBytes { + if wanted <= atomicIOBytes { return 0, syserror.ErrWouldBlock } ops.limit(avail) diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go index fe97e9800..3dd739080 100644 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ b/pkg/sentry/kernel/pipe/pipe_test.go @@ -26,7 +26,7 @@ import ( func TestPipeRW(t *testing.T) { ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) + r, w := NewConnectedPipe(ctx, 65536) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -46,7 +46,7 @@ func TestPipeRW(t *testing.T) { func TestPipeReadBlock(t *testing.T) { ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) + r, w := NewConnectedPipe(ctx, 65536) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -61,7 +61,7 @@ func TestPipeWriteBlock(t *testing.T) { const capacity = MinimumPipeSize ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes) + r, w := NewConnectedPipe(ctx, capacity) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -76,7 +76,7 @@ func TestPipeWriteUntilEnd(t *testing.T) { const atomicIOBytes = 2 ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes) + r, w := NewConnectedPipe(ctx, atomicIOBytes) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -116,7 +116,8 @@ func TestPipeWriteUntilEnd(t *testing.T) { } } if err != nil { - t.Fatalf("Readv: got unexpected error %v", err) + t.Errorf("Readv: got unexpected error %v", err) + return } } }() diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 1a152142b..7b23cbe86 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -33,6 +33,8 @@ import ( // VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should // not be copied. +// +// +stateify savable type VFSPipe struct { // mu protects the fields below. mu sync.Mutex `state:"nosave"` @@ -52,9 +54,9 @@ type VFSPipe struct { } // NewVFSPipe returns an initialized VFSPipe. -func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe { +func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe { var vp VFSPipe - initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes) + initPipe(&vp.pipe, isNamed, sizeBytes) return &vp } @@ -164,6 +166,8 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l // VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements // non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to // other FileDescriptions for splice(2) and tee(2). +// +// +stateify savable type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 1145faf13..cef58a590 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) { Signo: int32(linux.SIGTRAP), Code: code, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) if t.beginPtraceStopLocked() { tracer := t.Tracer() tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP)) @@ -1000,7 +1000,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { // at the address specified by the data parameter, and the return value // is the error flag." - ptrace(2) word := t.Arch().Native(0) - if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { + if _, err := word.CopyIn(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { return err } _, err := word.CopyOut(t, data) @@ -1008,7 +1008,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA: word := t.Arch().Native(uintptr(data)) - _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr) + _, err := word.CopyOut(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr) return err case linux.PTRACE_GETREGSET: diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index 387edfa91..60917e7d3 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -106,7 +106,7 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 { data := linux.SeccompData{ Nr: sysno, - Arch: t.tc.st.AuditNumber, + Arch: t.image.st.AuditNumber, InstructionPointer: uint64(ip), } // data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index c00fa1138..3dd3953b3 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -29,17 +29,17 @@ import ( ) const ( - valueMax = 32767 // SEMVMX + // Maximum semaphore value. + valueMax = linux.SEMVMX - // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL). - semaphoresMax = 32000 + // Maximum number of semaphore sets. + setsMax = linux.SEMMNI - // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI). - setsMax = 32000 + // Maximum number of semaphroes in a semaphore set. + semsMax = linux.SEMMSL - // semaphoresTotalMax is "system-wide limit on the number of semaphores" - // (SEMMNS = SEMMNI*SEMMSL). - semaphoresTotalMax = 1024000000 + // Maximum number of semaphores in all semaphroe sets. + semsTotalMax = linux.SEMMNS ) // Registry maintains a set of semaphores that can be found by key or ID. @@ -52,6 +52,9 @@ type Registry struct { mu sync.Mutex `state:"nosave"` semaphores map[int32]*Set lastIDUsed int32 + // indexes maintains a mapping between a set's index in virtual array and + // its identifier. + indexes map[int32]int32 } // Set represents a set of semaphores that can be operated atomically. @@ -103,6 +106,7 @@ type waiter struct { waiterEntry // value represents how much resource the waiter needs to wake up. + // The value is either 0 or negative. value int16 ch chan struct{} } @@ -112,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { return &Registry{ userNS: userNS, semaphores: make(map[int32]*Set), + indexes: make(map[int32]int32), } } @@ -121,7 +126,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { // be found. If exclusive is true, it fails if a set with the same key already // exists. func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) { - if nsems < 0 || nsems > semaphoresMax { + if nsems < 0 || nsems > semsMax { return nil, syserror.EINVAL } @@ -162,10 +167,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu } // Apply system limits. + // + // Map semaphores and map indexes in a registry are of the same size, + // check map semaphores only here for the system limit. if len(r.semaphores) >= setsMax { return nil, syserror.EINVAL } - if r.totalSems() > int(semaphoresTotalMax-nsems) { + if r.totalSems() > int(semsTotalMax-nsems) { return nil, syserror.EINVAL } @@ -175,6 +183,39 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu return r.newSet(ctx, key, owner, owner, perms, nsems) } +// IPCInfo returns information about system-wide semaphore limits and parameters. +func (r *Registry) IPCInfo() *linux.SemInfo { + return &linux.SemInfo{ + SemMap: linux.SEMMAP, + SemMni: linux.SEMMNI, + SemMns: linux.SEMMNS, + SemMnu: linux.SEMMNU, + SemMsl: linux.SEMMSL, + SemOpm: linux.SEMOPM, + SemUme: linux.SEMUME, + SemUsz: 0, // SemUsz not supported. + SemVmx: linux.SEMVMX, + SemAem: linux.SEMAEM, + } +} + +// HighestIndex returns the index of the highest used entry in +// the kernel's array. +func (r *Registry) HighestIndex() int32 { + r.mu.Lock() + defer r.mu.Unlock() + + // By default, highest used index is 0 even though + // there is no semaphroe set. + var highestIndex int32 + for index := range r.indexes { + if index > highestIndex { + highestIndex = index + } + } + return highestIndex +} + // RemoveID removes set with give 'id' from the registry and marks the set as // dead. All waiters will be awakened and fail. func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { @@ -185,6 +226,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { if set == nil { return syserror.EINVAL } + index, found := r.findIndexByID(id) + if !found { + // Inconsistent state. + panic(fmt.Sprintf("unable to find an index for ID: %d", id)) + } set.mu.Lock() defer set.mu.Unlock() @@ -196,6 +242,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { } delete(r.semaphores, set.ID) + delete(r.indexes, index) set.destroy() return nil } @@ -219,6 +266,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File continue } if r.semaphores[id] == nil { + index, found := r.findFirstAvailableIndex() + if !found { + panic("unable to find an available index") + } + r.indexes[index] = id r.lastIDUsed = id r.semaphores[id] = set set.ID = id @@ -246,6 +298,24 @@ func (r *Registry) findByKey(key int32) *Set { return nil } +func (r *Registry) findIndexByID(id int32) (int32, bool) { + for k, v := range r.indexes { + if v == id { + return k, true + } + } + return 0, false +} + +func (r *Registry) findFirstAvailableIndex() (int32, bool) { + for index := int32(0); index < setsMax; index++ { + if _, present := r.indexes[index]; !present { + return index, true + } + } + return 0, false +} + func (r *Registry) totalSems() int { totalSems := 0 for _, v := range r.semaphores { @@ -283,6 +353,33 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File return nil } +// GetStat extracts semid_ds information from the set. +func (s *Set) GetStat(creds *auth.Credentials) (*linux.SemidDS, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // "The calling process must have read permission on the semaphore set." + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return nil, syserror.EACCES + } + + ds := &linux.SemidDS{ + SemPerm: linux.IPCPerm{ + Key: uint32(s.key), + UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)), + GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)), + CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)), + CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)), + Mode: uint16(s.perms.LinuxMode()), + Seq: 0, // IPC sequence not supported. + }, + SemOTime: s.opTime.TimeT(), + SemCTime: s.changeTime.TimeT(), + SemNSems: uint64(s.Size()), + } + return ds, nil +} + // SetVal overrides a semaphore value, waking up waiters as needed. func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error { if val < 0 || val > valueMax { @@ -320,7 +417,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti } for _, val := range vals { - if val < 0 || val > valueMax { + if val > valueMax { return syserror.ERANGE } } @@ -396,6 +493,42 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) { return sem.pid, nil } +func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *waiter) bool) (uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // The calling process must have read permission on the semaphore set. + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return 0, syserror.EACCES + } + + sem := s.findSem(num) + if sem == nil { + return 0, syserror.ERANGE + } + var cnt uint16 + for w := sem.waiters.Front(); w != nil; w = w.Next() { + if pred(w) { + cnt++ + } + } + return cnt, nil +} + +// CountZeroWaiters returns number of waiters waiting for the sem's value to increase. +func (s *Set) CountZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value == 0 + }) +} + +// CountNegativeWaiters returns number of waiters waiting for the sem to go to zero. +func (s *Set) CountNegativeWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value < 0 + }) +} + // ExecuteOps attempts to execute a list of operations to the set. It only // succeeds when all operations can be applied. No changes are made if it fails. // @@ -548,11 +681,18 @@ func (s *Set) destroy() { } } +func abs(val int16) int16 { + if val < 0 { + return -val + } + return val +} + // wakeWaiters goes over all waiters and checks which of them can be notified. func (s *sem) wakeWaiters() { // Note that this will release all waiters waiting for 0 too. for w := s.waiters.Front(); w != nil; { - if s.value < w.value { + if s.value < abs(w.value) { // Still blocked, skip it. w = w.Next() continue diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index df5c8421b..0cd9e2533 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -295,7 +295,7 @@ func (tg *ThreadGroup) createSession() error { id: SessionID(id), leader: tg, } - s.EnableLeakCheck() + s.InitRefs() // Create a new ProcessGroup, belonging to that Session. // This also has a single reference (assigned below). @@ -309,7 +309,7 @@ func (tg *ThreadGroup) createSession() error { session: s, ancestors: 0, } - pg.refs.EnableLeakCheck() + pg.refs.InitRefs() // Tie them and return the result. s.processGroups.PushBack(pg) @@ -395,7 +395,7 @@ func (tg *ThreadGroup) CreateProcessGroup() error { originator: tg, session: tg.processGroup.session, } - pg.refs.EnableLeakCheck() + pg.refs.InitRefs() if tg.leader.parent != nil && tg.leader.parent.tg.processGroup.session == pg.session { pg.ancestors++ @@ -477,20 +477,20 @@ func (tg *ThreadGroup) Session() *Session { // // If this group isn't visible in this namespace, zero will be returned. It is // the callers responsibility to check that before using this function. -func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.sids[s] +func (ns *PIDNamespace) IDOfSession(s *Session) SessionID { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.sids[s] } // SessionWithID returns the Session with the given ID in the PID namespace ns, // or nil if that given ID is not defined in this namespace. // // A reference is not taken on the session. -func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.sessions[id] +func (ns *PIDNamespace) SessionWithID(id SessionID) *Session { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.sessions[id] } // ProcessGroup returns the ThreadGroup's ProcessGroup. @@ -505,18 +505,18 @@ func (tg *ThreadGroup) ProcessGroup() *ProcessGroup { // IDOfProcessGroup returns the process group assigned to pg in PID namespace ns. // // The same constraints apply as IDOfSession. -func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.pgids[pg] +func (ns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.pgids[pg] } // ProcessGroupWithID returns the ProcessGroup with the given ID in the PID // namespace ns, or nil if that given ID is not defined in this namespace. // // A reference is not taken on the process group. -func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.processGroups[id] +func (ns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.processGroups[id] } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index f8a382fd8..073e14507 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -6,9 +6,12 @@ package(licenses = ["notice"]) go_template_instance( name = "shm_refs", out = "shm_refs.go", + consts = { + "enableLogging": "true", + }, package = "shm", prefix = "Shm", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Shm", }, @@ -27,7 +30,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index ebbebf46b..92d60ba78 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -251,7 +251,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi creatorPID: pid, changeTime: ktime.NowFromContext(ctx), } - shm.EnableLeakCheck() + shm.InitRefs() // Find the next available ID. for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ { diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go index e8cce37d0..2488ae7d5 100644 --- a/pkg/sentry/kernel/signal.go +++ b/pkg/sentry/kernel/signal.go @@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 78f718cfe..884966120 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go index 90f890495..0b17a562e 100644 --- a/pkg/sentry/kernel/syscalls_state.go +++ b/pkg/sentry/kernel/syscalls_state.go @@ -30,18 +30,18 @@ type syscallTableInfo struct { } // saveSt saves the SyscallTable. -func (tc *TaskContext) saveSt() syscallTableInfo { +func (image *TaskImage) saveSt() syscallTableInfo { return syscallTableInfo{ - OS: tc.st.OS, - Arch: tc.st.Arch, + OS: image.st.OS, + Arch: image.st.Arch, } } // loadSt loads the SyscallTable. -func (tc *TaskContext) loadSt(sti syscallTableInfo) { +func (image *TaskImage) loadSt(sti syscallTableInfo) { st, ok := LookupSyscallTable(sti.OS, sti.Arch) if !ok { panic(fmt.Sprintf("syscall table not found for OS %v, Arch %v", sti.OS, sti.Arch)) } - tc.st = st // Save the table reference. + image.st = st // Save the table reference. } diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index a83ce219c..3fee7aa68 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -75,6 +75,12 @@ func (s *syslog) Log() []byte { "Checking naughty and nice process list...", // Check it up to twice. "Granting licence to kill(2)...", // British spelling for British movie. "Letting the watchdogs out...", + "Conjuring /dev/null black hole...", + "Adversarially training Redcode AI...", + "Singleplexing /dev/ptmx...", + "Recruiting cron-ies...", + "Verifying that no non-zero bytes made their way into /dev/zero...", + "Accelerating teletypewriter to 9600 baud...", } selectMessage := func() string { diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 037971393..c0ab53c94 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" - "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -29,11 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/limits" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -63,6 +58,12 @@ import ( type Task struct { taskNode + // goid is the task goroutine's ID. goid is owned by the task goroutine, + // but since it's used to detect cases where non-task goroutines + // incorrectly access state owned by, or exclusive to, the task goroutine, + // goid is always accessed using atomic memory operations. + goid int64 `state:"nosave"` + // runState is what the task goroutine is executing if it is not stopped. // If runState is nil, the task goroutine should exit or has exited. // runState is exclusive to the task goroutine. @@ -83,7 +84,7 @@ type Task struct { // taskWork is exclusive to the task goroutine. taskWork []TaskWorker - // haveSyscallReturn is true if tc.Arch().Return() represents a value + // haveSyscallReturn is true if image.Arch().Return() represents a value // returned by a syscall (or set by ptrace after a syscall). // // haveSyscallReturn is exclusive to the task goroutine. @@ -257,10 +258,10 @@ type Task struct { // mu protects some of the following fields. mu sync.Mutex `state:"nosave"` - // tc holds task data provided by the ELF loader. + // image holds task data provided by the ELF loader. // - // tc is protected by mu, and is owned by the task goroutine. - tc TaskContext + // image is protected by mu, and is owned by the task goroutine. + image TaskImage // fsContext is the task's filesystem context. // @@ -274,7 +275,7 @@ type Task struct { // If vforkParent is not nil, it is the task that created this task with // vfork() or clone(CLONE_VFORK), and should have its vforkStop ended when - // this TaskContext is released. + // this TaskImage is released. // // vforkParent is protected by the TaskSet mutex. vforkParent *Task @@ -641,64 +642,6 @@ func (t *Task) Kernel() *Kernel { return t.k } -// Value implements context.Context.Value. -// -// Preconditions: The caller must be running on the task goroutine (as implied -// by the requirements of context.Context). -func (t *Task) Value(key interface{}) interface{} { - switch key { - case CtxCanTrace: - return t.CanTrace - case CtxKernel: - return t.k - case CtxPIDNamespace: - return t.tg.pidns - case CtxUTSNamespace: - return t.utsns - case CtxIPCNamespace: - ipcns := t.IPCNamespace() - ipcns.IncRef() - return ipcns - case CtxTask: - return t - case auth.CtxCredentials: - return t.Credentials() - case context.CtxThreadGroupID: - return int32(t.ThreadGroup().ID()) - case fs.CtxRoot: - return t.fsContext.RootDirectory() - case vfs.CtxRoot: - return t.fsContext.RootDirectoryVFS2() - case vfs.CtxMountNamespace: - t.mountNamespaceVFS2.IncRef() - return t.mountNamespaceVFS2 - case fs.CtxDirentCacheLimiter: - return t.k.DirentCacheLimiter - case inet.CtxStack: - return t.NetworkContext() - case ktime.CtxRealtimeClock: - return t.k.RealtimeClock() - case limits.CtxLimits: - return t.tg.limits - case pgalloc.CtxMemoryFile: - return t.k.mf - case pgalloc.CtxMemoryFileProvider: - return t.k - case platform.CtxPlatform: - return t.k - case uniqueid.CtxGlobalUniqueID: - return t.k.UniqueID() - case uniqueid.CtxGlobalUniqueIDProvider: - return t.k - case uniqueid.CtxInotifyCookie: - return t.k.GenerateInotifyCookie() - case unimpl.CtxEvents: - return t.k - default: - return nil - } -} - // SetClearTID sets t's cleartid. // // Preconditions: The caller must be running on the task goroutine. @@ -751,12 +694,12 @@ func (t *Task) IsChrooted() bool { return root != realRoot } -// TaskContext returns t's TaskContext. +// TaskImage returns t's TaskImage. // // Precondition: The caller must be running on the task goroutine, or t.mu must // be locked. -func (t *Task) TaskContext() *TaskContext { - return &t.tc +func (t *Task) TaskImage() *TaskImage { + return &t.image } // FSContext returns t's FSContext. FSContext does not take an additional diff --git a/pkg/sentry/kernel/task_acct.go b/pkg/sentry/kernel/task_acct.go index 5f3e60fe8..e574997f7 100644 --- a/pkg/sentry/kernel/task_acct.go +++ b/pkg/sentry/kernel/task_acct.go @@ -136,14 +136,14 @@ func (tg *ThreadGroup) IOUsage() *usage.IO { func (t *Task) Name() string { t.mu.Lock() defer t.mu.Unlock() - return t.tc.Name + return t.image.Name } // SetName changes t's name. func (t *Task) SetName(name string) { t.mu.Lock() defer t.mu.Unlock() - t.tc.Name = name + t.image.Name = name t.Debugf("Set thread name to %q", name) } diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index 4a4a69ee2..9419f2e95 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -20,6 +20,7 @@ import ( "time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -32,6 +33,8 @@ import ( // // - An error which is nil if an event is received from C, ETIMEDOUT if the timeout // expired, and syserror.ErrInterrupted if t is interrupted. +// +// Preconditions: The caller must be running on the task goroutine. func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.Duration) (time.Duration, error) { if !haveTimeout { return timeout, t.block(C, nil) @@ -112,7 +115,14 @@ func (t *Task) Block(C <-chan struct{}) error { // block blocks a task on one of many events. // N.B. defer is too expensive to be used here. +// +// Preconditions: The caller must be running on the task goroutine. func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { + // This function is very hot; skip this check outside of +race builds. + if sync.RaceEnabled { + t.assertTaskGoroutine() + } + // Fast path if the request is already done. select { case <-C: @@ -156,33 +166,39 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error { } } -// SleepStart implements amutex.Sleeper.SleepStart. +// SleepStart implements context.ChannelSleeper.SleepStart. func (t *Task) SleepStart() <-chan struct{} { + t.assertTaskGoroutine() t.Deactivate() t.accountTaskGoroutineEnter(TaskGoroutineBlockedInterruptible) return t.interruptChan } -// SleepFinish implements amutex.Sleeper.SleepFinish. +// SleepFinish implements context.ChannelSleeper.SleepFinish. func (t *Task) SleepFinish(success bool) { if !success { - // The interrupted notification is consumed only at the top-level - // (Run). Therefore we attempt to reset the pending notification. - // This will also elide our next entry back into the task, so we - // will process signals, state changes, etc. + // Our caller received from t.interruptChan; we need to re-send to it + // to ensure that t.interrupted() is still true. t.interruptSelf() } t.accountTaskGoroutineLeave(TaskGoroutineBlockedInterruptible) t.Activate() } -// Interrupted implements amutex.Sleeper.Interrupted +// Interrupted implements context.ChannelSleeper.Interrupted. func (t *Task) Interrupted() bool { - return len(t.interruptChan) != 0 + if t.interrupted() { + return true + } + // Indicate that t's task goroutine is still responsive (i.e. reset the + // watchdog timer). + t.accountTaskGoroutineRunning() + return false } // UninterruptibleSleepStart implements context.Context.UninterruptibleSleepStart. func (t *Task) UninterruptibleSleepStart(deactivate bool) { + t.assertTaskGoroutine() if deactivate { t.Deactivate() } @@ -198,13 +214,17 @@ func (t *Task) UninterruptibleSleepFinish(activate bool) { } // interrupted returns true if interrupt or interruptSelf has been called at -// least once since the last call to interrupted. +// least once since the last call to unsetInterrupted. func (t *Task) interrupted() bool { + return len(t.interruptChan) != 0 +} + +// unsetInterrupted causes interrupted to return false until the next call to +// interrupt or interruptSelf. +func (t *Task) unsetInterrupted() { select { case <-t.interruptChan: - return true default: - return false } } @@ -220,9 +240,7 @@ func (t *Task) interrupt() { func (t *Task) interruptSelf() { select { case t.interruptChan <- struct{}{}: - t.Debugf("Interrupt queued") default: - t.Debugf("Dropping duplicate interrupt") } // platform.Context.Interrupt() is unnecessary since a task goroutine // calling interruptSelf() cannot also be blocked in diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 682080c14..f305e69c0 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -115,7 +115,7 @@ type CloneOptions struct { ParentTID usermem.Addr // If Vfork is true, place the parent in vforkStop until the cloned task - // releases its TaskContext. + // releases its TaskImage. Vfork bool // If Untraced is true, do not report PTRACE_EVENT_CLONE/FORK/VFORK for @@ -226,20 +226,20 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { }) } - tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace) + image, err := t.image.Fork(t, t.k, !opts.NewAddressSpace) if err != nil { return 0, nil, err } cu.Add(func() { - tc.release() + image.release() }) // clone() returns 0 in the child. - tc.Arch.SetReturn(0) + image.Arch.SetReturn(0) if opts.Stack != 0 { - tc.Arch.SetStack(uintptr(opts.Stack)) + image.Arch.SetStack(uintptr(opts.Stack)) } if opts.SetTLS { - if !tc.Arch.SetTLS(uintptr(opts.TLS)) { + if !image.Arch.SetTLS(uintptr(opts.TLS)) { return 0, nil, syserror.EPERM } } @@ -288,7 +288,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { Kernel: t.k, ThreadGroup: tg, SignalMask: t.SignalMask(), - TaskContext: tc, + TaskImage: image, FSContext: fsContext, FDTable: fdTable, Credentials: creds, @@ -355,7 +355,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } if opts.ChildSetTID { ctid := nt.ThreadID() - ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) + ctid.CopyOut(nt.CopyContext(t, usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) } ntid := t.tg.pidns.IDOfTask(nt) if opts.ParentSetTID { diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index d1136461a..70b0699dc 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,159 +15,175 @@ package kernel import ( - "fmt" + "time" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/kernel/futex" - "gvisor.dev/gvisor/pkg/sentry/loader" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/unimpl" + "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" ) -var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC) - -// Auxmap contains miscellaneous data for the task. -type Auxmap map[string]interface{} - -// TaskContext is the subset of a task's data that is provided by the loader. -// -// +stateify savable -type TaskContext struct { - // Name is the thread name set by the prctl(PR_SET_NAME) system call. - Name string - - // Arch is the architecture-specific context (registers, etc.) - Arch arch.Context - - // MemoryManager is the task's address space. - MemoryManager *mm.MemoryManager +// Deadline implements context.Context.Deadline. +func (t *Task) Deadline() (time.Time, bool) { + return time.Time{}, false +} - // fu implements futexes in the address space. - fu *futex.Manager +// Done implements context.Context.Done. +func (t *Task) Done() <-chan struct{} { + return nil +} - // st is the task's syscall table. - st *SyscallTable `state:".(syscallTableInfo)"` +// Err implements context.Context.Err. +func (t *Task) Err() error { + return nil } -// release releases all resources held by the TaskContext. release is called by -// the task when it execs into a new TaskContext or exits. -func (tc *TaskContext) release() { - // Nil out pointers so that if the task is saved after release, it doesn't - // follow the pointers to possibly now-invalid objects. - if tc.MemoryManager != nil { - tc.MemoryManager.DecUsers(context.Background()) - tc.MemoryManager = nil +// Value implements context.Context.Value. +// +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) Value(key interface{}) interface{} { + // This function is very hot; skip this check outside of +race builds. + if sync.RaceEnabled { + t.assertTaskGoroutine() } - tc.fu = nil + return t.contextValue(key, true /* isTaskGoroutine */) } -// Fork returns a duplicate of tc. The copied TaskContext always has an -// independent arch.Context. If shareAddressSpace is true, the copied -// TaskContext shares an address space with the original; otherwise, the copied -// TaskContext has an independent address space that is initially a duplicate -// of the original's. -func (tc *TaskContext) Fork(ctx context.Context, k *Kernel, shareAddressSpace bool) (*TaskContext, error) { - newTC := &TaskContext{ - Name: tc.Name, - Arch: tc.Arch.Fork(), - st: tc.st, - } - if shareAddressSpace { - newTC.MemoryManager = tc.MemoryManager - if newTC.MemoryManager != nil { - if !newTC.MemoryManager.IncUsers() { - // Shouldn't be possible since tc.MemoryManager should be a - // counted user. - panic(fmt.Sprintf("TaskContext.Fork called with userless TaskContext.MemoryManager")) - } +func (t *Task) contextValue(key interface{}, isTaskGoroutine bool) interface{} { + switch key { + case CtxCanTrace: + return t.CanTrace + case CtxKernel: + return t.k + case CtxPIDNamespace: + return t.tg.pidns + case CtxUTSNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + return t.utsns + case CtxIPCNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + ipcns := t.ipcns + ipcns.IncRef() + return ipcns + case CtxTask: + return t + case auth.CtxCredentials: + return t.creds.Load() + case context.CtxThreadGroupID: + return int32(t.tg.ID()) + case fs.CtxRoot: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() + } + return t.fsContext.RootDirectory() + case vfs.CtxRoot: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() } - newTC.fu = tc.fu - } else { - newMM, err := tc.MemoryManager.Fork(ctx) - if err != nil { - return nil, err + return t.fsContext.RootDirectoryVFS2() + case vfs.CtxMountNamespace: + if !isTaskGoroutine { + t.mu.Lock() + defer t.mu.Unlock() } - newTC.MemoryManager = newMM - newTC.fu = k.futexes.Fork() + t.mountNamespaceVFS2.IncRef() + return t.mountNamespaceVFS2 + case fs.CtxDirentCacheLimiter: + return t.k.DirentCacheLimiter + case inet.CtxStack: + return t.NetworkContext() + case ktime.CtxRealtimeClock: + return t.k.RealtimeClock() + case limits.CtxLimits: + return t.tg.limits + case pgalloc.CtxMemoryFile: + return t.k.mf + case pgalloc.CtxMemoryFileProvider: + return t.k + case platform.CtxPlatform: + return t.k + case uniqueid.CtxGlobalUniqueID: + return t.k.UniqueID() + case uniqueid.CtxGlobalUniqueIDProvider: + return t.k + case uniqueid.CtxInotifyCookie: + return t.k.GenerateInotifyCookie() + case unimpl.CtxEvents: + return t.k + default: + return nil } - return newTC, nil } -// Arch returns t's arch.Context. -// -// Preconditions: The caller must be running on the task goroutine, or t.mu -// must be locked. -func (t *Task) Arch() arch.Context { - return t.tc.Arch +// taskAsyncContext implements context.Context for a goroutine that performs +// work on behalf of a Task, but is not the task goroutine. +type taskAsyncContext struct { + context.NoopSleeper + + t *Task } -// MemoryManager returns t's MemoryManager. MemoryManager does not take an -// additional reference on the returned MM. -// -// Preconditions: The caller must be running on the task goroutine, or t.mu -// must be locked. -func (t *Task) MemoryManager() *mm.MemoryManager { - return t.tc.MemoryManager +// AsyncContext returns a context.Context representing t. The returned +// context.Context is intended for use by goroutines other than t's task +// goroutine; for example, signal delivery to t will not interrupt goroutines +// that are blocking using the returned context.Context. +func (t *Task) AsyncContext() context.Context { + return taskAsyncContext{t: t} } -// SyscallTable returns t's syscall table. -// -// Preconditions: The caller must be running on the task goroutine, or t.mu -// must be locked. -func (t *Task) SyscallTable() *SyscallTable { - return t.tc.st +// Debugf implements log.Logger.Debugf. +func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) { + ctx.t.Debugf(format, v...) } -// Stack returns the userspace stack. -// -// Preconditions: The caller must be running on the task goroutine, or t.mu -// must be locked. -func (t *Task) Stack() *arch.Stack { - return &arch.Stack{ - Arch: t.Arch(), - IO: t.MemoryManager(), - Bottom: usermem.Addr(t.Arch().Stack()), - } +// Infof implements log.Logger.Infof. +func (ctx taskAsyncContext) Infof(format string, v ...interface{}) { + ctx.t.Infof(format, v...) } -// LoadTaskImage loads a specified file into a new TaskContext. -// -// args.MemoryManager does not need to be set by the caller. -func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) { - // If File is not nil, we should load that instead of resolving Filename. - if args.File != nil { - args.Filename = args.File.PathnameWithDeleted(ctx) - } +// Warningf implements log.Logger.Warningf. +func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) { + ctx.t.Warningf(format, v...) +} + +// IsLogging implements log.Logger.IsLogging. +func (ctx taskAsyncContext) IsLogging(level log.Level) bool { + return ctx.t.IsLogging(level) +} - // Prepare a new user address space to load into. - m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) - defer m.DecUsers(ctx) - args.MemoryManager = m +// Deadline implements context.Context.Deadline. +func (ctx taskAsyncContext) Deadline() (time.Time, bool) { + return time.Time{}, false +} - os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) - if err != nil { - return nil, err - } +// Done implements context.Context.Done. +func (ctx taskAsyncContext) Done() <-chan struct{} { + return nil +} - // Lookup our new syscall table. - st, ok := LookupSyscallTable(os, ac.Arch()) - if !ok { - // No syscall table found. This means that the ELF binary does not match - // the architecture. - return nil, errNoSyscalls - } +// Err implements context.Context.Err. +func (ctx taskAsyncContext) Err() error { + return nil +} - if !m.IncUsers() { - panic("Failed to increment users count on new MM") - } - return &TaskContext{ - Name: name, - Arch: ac, - MemoryManager: m, - fu: k.futexes.Fork(), - st: st, - }, nil +// Value implements context.Context.Value. +func (ctx taskAsyncContext) Value(key interface{}) interface{} { + return ctx.t.contextValue(key, false /* isTaskGoroutine */) } diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 412d471d3..d9897e802 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -83,11 +83,12 @@ type execStop struct{} func (*execStop) Killable() bool { return true } // Execve implements the execve(2) syscall by killing all other tasks in its -// thread group and switching to newTC. Execve always takes ownership of newTC. +// thread group and switching to newImage. Execve always takes ownership of +// newImage. // // Preconditions: The caller must be running Task.doSyscallInvoke on the task // goroutine. -func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) { +func (t *Task) Execve(newImage *TaskImage) (*SyscallControl, error) { t.tg.pidns.owner.mu.Lock() defer t.tg.pidns.owner.mu.Unlock() t.tg.signalHandlers.mu.Lock() @@ -96,7 +97,7 @@ func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) { if t.tg.exiting || t.tg.execing != nil { // We lost to a racing group-exit, kill, or exec from another thread // and should just exit. - newTC.release() + newImage.release() return nil, syserror.EINTR } @@ -118,7 +119,7 @@ func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) { t.beginInternalStopLocked((*execStop)(nil)) } - return &SyscallControl{next: &runSyscallAfterExecStop{newTC}, ignoreReturn: true}, nil + return &SyscallControl{next: &runSyscallAfterExecStop{newImage}, ignoreReturn: true}, nil } // The runSyscallAfterExecStop state continues execve(2) after all siblings of @@ -126,16 +127,16 @@ func (t *Task) Execve(newTC *TaskContext) (*SyscallControl, error) { // // +stateify savable type runSyscallAfterExecStop struct { - tc *TaskContext + image *TaskImage } func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { - t.traceExecEvent(r.tc) + t.traceExecEvent(r.image) t.tg.pidns.owner.mu.Lock() t.tg.execing = nil if t.killed() { t.tg.pidns.owner.mu.Unlock() - r.tc.release() + r.image.release() return (*runInterrupt)(nil) } // We are the thread group leader now. Save our old thread ID for @@ -214,7 +215,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { // executables (set-user/group-ID bits and file capabilities). This // allows us to unconditionally enable user dumpability on the new mm. // See fs/exec.c:setup_new_exec. - r.tc.MemoryManager.SetDumpability(mm.UserDumpable) + r.image.MemoryManager.SetDumpability(mm.UserDumpable) // Switch to the new process. t.MemoryManager().Deactivate() @@ -222,8 +223,8 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { // Update credentials to reflect the execve. This should precede switching // MMs to ensure that dumpability has been reset first, if needed. t.updateCredsForExecLocked() - t.tc.release() - t.tc = *r.tc + t.image.release() + t.image = *r.image t.mu.Unlock() t.unstopVforkParent() t.p.FullStateChanged() diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index ce7b9641d..16986244c 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -266,7 +266,7 @@ func (*runExitMain) execute(t *Task) taskRunState { t.updateRSSLocked() t.tg.pidns.owner.mu.Unlock() t.mu.Lock() - t.tc.release() + t.image.release() t.mu.Unlock() // Releasing the MM unblocks a blocked CLONE_VFORK parent. @@ -368,8 +368,8 @@ func (t *Task) exitChildren() { Signo: int32(sig), Code: arch.SignalInfoUser, } - siginfo.SetPid(int32(c.tg.pidns.tids[t])) - siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) + siginfo.SetPID(int32(c.tg.pidns.tids[t])) + siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) c.tg.signalHandlers.mu.Lock() c.sendSignalLocked(siginfo, true /* group */) c.tg.signalHandlers.mu.Unlock() @@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si info := &arch.SignalInfo{ Signo: int32(sig), } - info.SetPid(int32(receiver.tg.pidns.tids[t])) - info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.tids[t])) + info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) if t.exitStatus.Signaled() { info.Code = arch.CLD_KILLED info.SetStatus(int32(t.exitStatus.Signo)) diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index c80391475..195c7da9b 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -26,7 +26,7 @@ import ( // Preconditions: The caller must be running on the task goroutine, or t.mu // must be locked. func (t *Task) Futex() *futex.Manager { - return t.tc.fu + return t.image.fu } // SwapUint32 implements futex.Target.SwapUint32. diff --git a/pkg/sentry/kernel/task_image.go b/pkg/sentry/kernel/task_image.go new file mode 100644 index 000000000..ce5fbd299 --- /dev/null +++ b/pkg/sentry/kernel/task_image.go @@ -0,0 +1,173 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel/futex" + "gvisor.dev/gvisor/pkg/sentry/loader" + "gvisor.dev/gvisor/pkg/sentry/mm" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/usermem" +) + +var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC) + +// Auxmap contains miscellaneous data for the task. +type Auxmap map[string]interface{} + +// TaskImage is the subset of a task's data that is provided by the loader. +// +// +stateify savable +type TaskImage struct { + // Name is the thread name set by the prctl(PR_SET_NAME) system call. + Name string + + // Arch is the architecture-specific context (registers, etc.) + Arch arch.Context + + // MemoryManager is the task's address space. + MemoryManager *mm.MemoryManager + + // fu implements futexes in the address space. + fu *futex.Manager + + // st is the task's syscall table. + st *SyscallTable `state:".(syscallTableInfo)"` +} + +// release releases all resources held by the TaskImage. release is called by +// the task when it execs into a new TaskImage or exits. +func (image *TaskImage) release() { + // Nil out pointers so that if the task is saved after release, it doesn't + // follow the pointers to possibly now-invalid objects. + if image.MemoryManager != nil { + image.MemoryManager.DecUsers(context.Background()) + image.MemoryManager = nil + } + image.fu = nil +} + +// Fork returns a duplicate of image. The copied TaskImage always has an +// independent arch.Context. If shareAddressSpace is true, the copied +// TaskImage shares an address space with the original; otherwise, the copied +// TaskImage has an independent address space that is initially a duplicate +// of the original's. +func (image *TaskImage) Fork(ctx context.Context, k *Kernel, shareAddressSpace bool) (*TaskImage, error) { + newImage := &TaskImage{ + Name: image.Name, + Arch: image.Arch.Fork(), + st: image.st, + } + if shareAddressSpace { + newImage.MemoryManager = image.MemoryManager + if newImage.MemoryManager != nil { + if !newImage.MemoryManager.IncUsers() { + // Shouldn't be possible since image.MemoryManager should be a + // counted user. + panic(fmt.Sprintf("TaskImage.Fork called with userless TaskImage.MemoryManager")) + } + } + newImage.fu = image.fu + } else { + newMM, err := image.MemoryManager.Fork(ctx) + if err != nil { + return nil, err + } + newImage.MemoryManager = newMM + newImage.fu = k.futexes.Fork() + } + return newImage, nil +} + +// Arch returns t's arch.Context. +// +// Preconditions: The caller must be running on the task goroutine, or t.mu +// must be locked. +func (t *Task) Arch() arch.Context { + return t.image.Arch +} + +// MemoryManager returns t's MemoryManager. MemoryManager does not take an +// additional reference on the returned MM. +// +// Preconditions: The caller must be running on the task goroutine, or t.mu +// must be locked. +func (t *Task) MemoryManager() *mm.MemoryManager { + return t.image.MemoryManager +} + +// SyscallTable returns t's syscall table. +// +// Preconditions: The caller must be running on the task goroutine, or t.mu +// must be locked. +func (t *Task) SyscallTable() *SyscallTable { + return t.image.st +} + +// Stack returns the userspace stack. +// +// Preconditions: The caller must be running on the task goroutine, or t.mu +// must be locked. +func (t *Task) Stack() *arch.Stack { + return &arch.Stack{ + Arch: t.Arch(), + IO: t.MemoryManager(), + Bottom: usermem.Addr(t.Arch().Stack()), + } +} + +// LoadTaskImage loads a specified file into a new TaskImage. +// +// args.MemoryManager does not need to be set by the caller. +func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskImage, *syserr.Error) { + // If File is not nil, we should load that instead of resolving Filename. + if args.File != nil { + args.Filename = args.File.PathnameWithDeleted(ctx) + } + + // Prepare a new user address space to load into. + m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation) + defer m.DecUsers(ctx) + args.MemoryManager = m + + os, ac, name, err := loader.Load(ctx, args, k.extraAuxv, k.vdso) + if err != nil { + return nil, err + } + + // Lookup our new syscall table. + st, ok := LookupSyscallTable(os, ac.Arch()) + if !ok { + // No syscall table found. This means that the ELF binary does not match + // the architecture. + return nil, errNoSyscalls + } + + if !m.IncUsers() { + panic("Failed to increment users count on new MM") + } + return &TaskImage{ + Name: name, + Arch: ac, + MemoryManager: m, + fu: k.futexes.Fork(), + st: st, + }, nil +} diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go index d23cea802..c70e5e6ce 100644 --- a/pkg/sentry/kernel/task_log.go +++ b/pkg/sentry/kernel/task_log.go @@ -19,6 +19,7 @@ import ( "runtime/trace" "sort" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/usermem" ) @@ -215,7 +216,7 @@ func (t *Task) rebuildTraceContext(tid ThreadID) { // arbitrarily large (in general it won't be, especially for cases // where we're collecting a brief profile), so using the TID is a // reasonable compromise in this case. - t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid)) + t.traceContext, t.traceTask = trace.NewTask(context.Background(), fmt.Sprintf("tid:%d", tid)) } // traceCloneEvent is called when a new task is spawned. @@ -237,11 +238,11 @@ func (t *Task) traceExitEvent() { } // traceExecEvent is called when a task calls exec. -func (t *Task) traceExecEvent(tc *TaskContext) { +func (t *Task) traceExecEvent(image *TaskImage) { if !trace.IsEnabled() { return } - file := tc.MemoryManager.Executable() + file := image.MemoryManager.Executable() if file == nil { trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>") return diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 8dc3fec90..3ccecf4b6 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -16,11 +16,13 @@ package kernel import ( "bytes" + "fmt" "runtime" "runtime/trace" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/goid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/hostcpu" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -57,6 +59,8 @@ type taskRunState interface { // make it visible in stack dumps. A goroutine for a given task can be identified // searching for Task.run()'s argument value. func (t *Task) run(threadID uintptr) { + atomic.StoreInt64(&t.goid, goid.Get()) + // Construct t.blockingTimer here. We do this here because we can't // reconstruct t.blockingTimer during restore in Task.afterLoad(), because // kernel.timekeeper.SetClocks() hasn't been called yet. @@ -99,6 +103,9 @@ func (t *Task) run(threadID uintptr) { t.tg.pidns.owner.runningGoroutines.Done() t.p.Release() + // Deferring this store triggers a false positive in the race + // detector (https://github.com/golang/go/issues/42599). + atomic.StoreInt64(&t.goid, 0) // Keep argument alive because stack trace for dead variables may not be correct. runtime.KeepAlive(threadID) return @@ -317,7 +324,7 @@ func (app *runApp) execute(t *Task) taskRunState { // region. We should be able to easily identify // vsyscalls by having a <fault><syscall> pair. if at.Execute { - if sysno, ok := t.tc.st.LookupEmulate(addr); ok { + if sysno, ok := t.image.st.LookupEmulate(addr); ok { return t.doVsyscall(addr, sysno) } } @@ -375,6 +382,19 @@ func (app *runApp) execute(t *Task) taskRunState { } } +// assertTaskGoroutine panics if the caller is not running on t's task +// goroutine. +func (t *Task) assertTaskGoroutine() { + if got, want := goid.Get(), atomic.LoadInt64(&t.goid); got != want { + panic(fmt.Sprintf("running on goroutine %d (task goroutine for kernel.Task %p is %d)", got, t, want)) + } +} + +// GoroutineID returns the ID of t's task goroutine. +func (t *Task) GoroutineID() int64 { + return atomic.LoadInt64(&t.goid) +} + // waitGoroutineStoppedOrExited blocks until t's task goroutine stops or exits. func (t *Task) waitGoroutineStoppedOrExited() { t.goroutineStopped.Wait() diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index 52c55d13d..9ba5f8d78 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -157,6 +157,18 @@ func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) { t.goschedSeq.EndWrite() } +// Preconditions: The caller must be running on the task goroutine. +func (t *Task) accountTaskGoroutineRunning() { + now := t.k.CPUClockNow() + if t.gosched.State != TaskGoroutineRunningSys { + panic(fmt.Sprintf("Task goroutine in state %v (expected %v)", t.gosched.State, TaskGoroutineRunningSys)) + } + t.goschedSeq.BeginWrite() + t.gosched.SysTicks += now - t.gosched.Timestamp + t.gosched.Timestamp = now + t.goschedSeq.EndWrite() +} + // TaskGoroutineSchedInfo returns a copy of t's task goroutine scheduling info. // Most clients should use t.CPUStats() instead. func (t *Task) TaskGoroutineSchedInfo() TaskGoroutineSchedInfo { diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index ebdb83061..75af3af79 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -619,9 +619,6 @@ func (t *Task) setSignalMaskLocked(mask linux.SignalSet) { return } }) - // We have to re-issue the interrupt consumed by t.interrupted() since - // it might have been for a different reason. - t.interruptSelf() } // Conversely, if the new mask unblocks any signals that were blocked by @@ -917,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { Signo: int32(linux.SIGCHLD), Code: code, } - sigchld.SetPid(int32(t.tg.pidns.tids[target])) - sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + sigchld.SetPID(int32(t.tg.pidns.tids[target])) + sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) sigchld.SetStatus(status) // TODO(b/72102453): Set utime, stime. t.sendSignalLocked(sigchld, true /* group */) @@ -931,10 +928,10 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { type runInterrupt struct{} func (*runInterrupt) execute(t *Task) taskRunState { - // Interrupts are de-duplicated (if t is interrupted twice before - // t.interrupted() is called, t.interrupted() will only return true once), - // so early exits from this function must re-enter the runInterrupt state - // to check for more interrupt-signaled conditions. + // Interrupts are de-duplicated (t.unsetInterrupted() will undo the effect + // of all previous calls to t.interrupted() regardless of how many such + // calls there have been), so early exits from this function must re-enter + // the runInterrupt state to check for more interrupt-signaled conditions. t.tg.signalHandlers.mu.Lock() @@ -1025,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState { Signo: int32(sig), Code: t.ptraceCode, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } else { t.ptraceCode = int32(sig) t.ptraceSiginfo = nil @@ -1080,6 +1077,7 @@ func (*runInterrupt) execute(t *Task) taskRunState { return t.deliverSignal(info, act) } + t.unsetInterrupted() t.tg.signalHandlers.mu.Unlock() return (*runApp)(nil) } @@ -1116,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { if parent == nil { // Tracer has detached and t was created by Kernel.CreateProcess(). // Pretend the parent is in an ancestor PID + user namespace. - info.SetPid(0) - info.SetUid(int32(auth.OverflowUID)) + info.SetPID(0) + info.SetUID(int32(auth.OverflowUID)) } else { - info.SetPid(int32(t.tg.pidns.tids[parent])) - info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + info.SetPID(int32(t.tg.pidns.tids[parent])) + info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } } t.tg.signalHandlers.mu.Lock() diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 8e28230cc..36e1384f1 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -46,10 +46,10 @@ type TaskConfig struct { // SignalMask is the new task's initial signal mask. SignalMask linux.SignalSet - // TaskContext is the TaskContext of the new task. Ownership of the - // TaskContext is transferred to TaskSet.NewTask, whether or not it + // TaskImage is the TaskImage of the new task. Ownership of the + // TaskImage is transferred to TaskSet.NewTask, whether or not it // succeeds. - TaskContext *TaskContext + TaskImage *TaskImage // FSContext is the FSContext of the new task. A reference must be held on // FSContext, which is transferred to TaskSet.NewTask whether or not it @@ -105,7 +105,7 @@ type TaskConfig struct { func (ts *TaskSet) NewTask(ctx context.Context, cfg *TaskConfig) (*Task, error) { t, err := ts.newTask(cfg) if err != nil { - cfg.TaskContext.release() + cfg.TaskImage.release() cfg.FSContext.DecRef(ctx) cfg.FDTable.DecRef(ctx) cfg.IPCNamespace.DecRef(ctx) @@ -121,7 +121,7 @@ func (ts *TaskSet) NewTask(ctx context.Context, cfg *TaskConfig) (*Task, error) // of cfg if it succeeds. func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { tg := cfg.ThreadGroup - tc := cfg.TaskContext + image := cfg.TaskImage t := &Task{ taskNode: taskNode{ tg: tg, @@ -132,7 +132,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { interruptChan: make(chan struct{}, 1), signalMask: cfg.SignalMask, signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable}, - tc: *tc, + image: *image, fsContext: cfg.FSContext, fdTable: cfg.FDTable, p: cfg.Kernel.Platform.NewContext(), diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index ce134bf54..94dabbcd8 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -18,7 +18,8 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -281,29 +282,89 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp }, nil } -// copyContext implements marshal.CopyContext. It wraps a task to allow copying -// memory to and from the task memory with custom usermem.IOOpts. -type copyContext struct { - *Task +type taskCopyContext struct { + ctx context.Context + t *Task opts usermem.IOOpts } -// AsCopyContext wraps the task and returns it as CopyContext. -func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext { - return ©Context{t, opts} +// CopyContext returns a marshal.CopyContext that copies to/from t's address +// space using opts. +func (t *Task) CopyContext(ctx context.Context, opts usermem.IOOpts) *taskCopyContext { + return &taskCopyContext{ + ctx: ctx, + t: t, + opts: opts, + } +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (cc *taskCopyContext) CopyScratchBuffer(size int) []byte { + if ctxTask, ok := cc.ctx.(*Task); ok { + return ctxTask.CopyScratchBuffer(size) + } + return make([]byte, size) +} + +func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) { + cc.t.mu.Lock() + tmm := cc.t.MemoryManager() + cc.t.mu.Unlock() + if !tmm.IncUsers() { + return nil, syserror.EFAULT + } + return tmm, nil +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + tmm, err := cc.getMemoryManager() + if err != nil { + return 0, err + } + defer tmm.DecUsers(cc.ctx) + return tmm.CopyIn(cc.ctx, addr, dst, cc.opts) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + tmm, err := cc.getMemoryManager() + if err != nil { + return 0, err + } + defer tmm.DecUsers(cc.ctx) + return tmm.CopyOut(cc.ctx, addr, src, cc.opts) +} + +type ownTaskCopyContext struct { + t *Task + opts usermem.IOOpts +} + +// OwnCopyContext returns a marshal.CopyContext that copies to/from t's address +// space using opts. The returned CopyContext may only be used by t's task +// goroutine. +// +// Since t already implements marshal.CopyContext, this is only needed to +// override the usermem.IOOpts used for the copy. +func (t *Task) OwnCopyContext(opts usermem.IOOpts) *ownTaskCopyContext { + return &ownTaskCopyContext{ + t: t, + opts: opts, + } } -// CopyInString copies a string in from the task's memory. -func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) { - return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts) +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte { + return cc.t.CopyScratchBuffer(size) } -// CopyInBytes copies task memory into dst from an IO context. -func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { - return t.MemoryManager().CopyIn(t, addr, dst, t.opts) +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts) } -// CopyOutBytes copies src into task memoryfrom an IO context. -func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { - return t.MemoryManager().CopyOut(t, addr, src, t.opts) +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts) } diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index 9bc452e67..9e5c2d26f 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -115,7 +115,7 @@ func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error { } if old != v.seq { - return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. Application may hang or get incorrect time from the VDSO.", old, v.seq) + return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d; application may hang or get incorrect time from the VDSO", old, v.seq) } v.seq = next diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 7fd77925f..49e21026e 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp // Translations must be contiguous and in increasing order of // Translation.Source. if i > 0 && ts[i-1].Source.End != t.Source.Start { - return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t) + return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t) } // At least part of each Translation must be required. if t.Source.Intersect(required).Length() == 0 { diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index b4a47ccca..6dbeccfe2 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -78,7 +78,7 @@ go_template_instance( out = "aio_mappable_refs.go", package = "mm", prefix = "aioMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "aioMappable", }, @@ -89,7 +89,7 @@ go_template_instance( out = "special_mappable_refs.go", package = "mm", prefix = "SpecialMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SpecialMappable", }, @@ -127,6 +127,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safecopy", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 7bf48cb2c..4c8cd38ed 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -252,7 +252,7 @@ func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) { return nil, err } m := aioMappable{mfp: mfp, fr: fr} - m.EnableLeakCheck() + m.InitRefs() return &m, nil } diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go index 3dabac1af..e8931922f 100644 --- a/pkg/sentry/mm/aio_context_state.go +++ b/pkg/sentry/mm/aio_context_state.go @@ -15,6 +15,6 @@ package mm // afterLoad is invoked by stateify. -func (a *AIOContext) afterLoad() { - a.requestReady = make(chan struct{}, 1) +func (ctx *AIOContext) afterLoad() { + ctx.requestReady = make(chan struct{}, 1) } diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 2dbe5b751..48d8b6a2b 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -44,7 +44,7 @@ type SpecialMappable struct { // Preconditions: fr.Length() != 0. func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable { m := SpecialMappable{mfp: mfp, fr: fr, name: name} - m.EnableLeakCheck() + m.InitRefs() return &m } diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 7c297fb9e..d99be7f46 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -423,11 +423,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File } if f.opts.ManualZeroing { - if err := f.forEachMappingSlice(fr, func(bs []byte) { - for i := range bs { - bs[i] = 0 - } - }); err != nil { + if err := f.manuallyZero(fr); err != nil { return memmap.FileRange{}, err } } @@ -560,19 +556,39 @@ func (f *MemoryFile) Decommit(fr memmap.FileRange) error { panic(fmt.Sprintf("invalid range: %v", fr)) } + if f.opts.ManualZeroing { + // FALLOC_FL_PUNCH_HOLE may not zero pages if ManualZeroing is in + // effect. + if err := f.manuallyZero(fr); err != nil { + return err + } + } else { + if err := f.decommitFile(fr); err != nil { + return err + } + } + + f.markDecommitted(fr) + return nil +} + +func (f *MemoryFile) manuallyZero(fr memmap.FileRange) error { + return f.forEachMappingSlice(fr, func(bs []byte) { + for i := range bs { + bs[i] = 0 + } + }) +} + +func (f *MemoryFile) decommitFile(fr memmap.FileRange) error { // "After a successful call, subsequent reads from this range will // return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with // FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2) - err := syscall.Fallocate( + return syscall.Fallocate( int(f.file.Fd()), _FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE, int64(fr.Start), int64(fr.Length())) - if err != nil { - return err - } - f.markDecommitted(fr) - return nil } func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { @@ -1044,20 +1060,20 @@ func (f *MemoryFile) runReclaim() { break } - if err := f.Decommit(fr); err != nil { - log.Warningf("Reclaim failed to decommit %v: %v", fr, err) - // Zero the pages manually. This won't reduce memory usage, but at - // least ensures that the pages will be zero when reallocated. - f.forEachMappingSlice(fr, func(bs []byte) { - for i := range bs { - bs[i] = 0 + // If ManualZeroing is in effect, pages will be zeroed on allocation + // and may not be freed by decommitFile, so calling decommitFile is + // unnecessary. + if !f.opts.ManualZeroing { + if err := f.decommitFile(fr); err != nil { + log.Warningf("Reclaim failed to decommit %v: %v", fr, err) + // Zero the pages manually. This won't reduce memory usage, but at + // least ensures that the pages will be zero when reallocated. + if err := f.manuallyZero(fr); err != nil { + panic(fmt.Sprintf("Reclaim failed to decommit or zero %v: %v", fr, err)) } - }) - // Pretend the pages were decommitted even though they weren't, - // since the memory accounting implementation has no idea how to - // deal with this. - f.markDecommitted(fr) + } } + f.markDecommitted(fr) f.markReclaimed(fr) } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 0a54dd30d..f8ccb7430 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -79,6 +79,25 @@ func bluepillStopGuest(c *vCPU) { c.runData.requestInterruptWindow = 0 } +// bluepillSigBus is reponsible for injecting NMI to trigger sigbus. +// +//go:nosplit +func bluepillSigBus(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_NMI, 0); errno != 0 { + throw("NMI injection failed") + } +} + +// bluepillHandleEnosys is reponsible for handling enosys error. +// +//go:nosplit +func bluepillHandleEnosys(c *vCPU) { + throw("run failed: ENOSYS") +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection. // //go:nosplit @@ -114,3 +133,10 @@ func bluepillReadyStopGuest(c *vCPU) bool { } return true } + +// bluepillArchHandleExit checks architecture specific exitcode. +// +//go:nosplit +func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) { + c.die(bluepillArchContext(context), "unknown") +} diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 58f3d6fdd..1f09813ba 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -27,15 +27,27 @@ var ( // The action for bluepillSignal is changed by sigaction(). bluepillSignal = syscall.SIGILL - // vcpuSErr is the event of system error. - vcpuSErr = kvmVcpuEvents{ + // vcpuSErrBounce is the event of system error for bouncing KVM. + vcpuSErrBounce = kvmVcpuEvents{ exception: exception{ sErrPending: 1, - sErrHasEsr: 0, - pad: [6]uint8{0, 0, 0, 0, 0, 0}, - sErrEsr: 1, }, - rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + } + + // vcpuSErrNMI is the event of system error to trigger sigbus. + vcpuSErrNMI = kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + sErrHasEsr: 1, + sErrEsr: _ESR_ELx_SERR_NMI, + }, + } + + // vcpuExtDabt is the event of ext_dabt. + vcpuExtDabt = kvmVcpuEvents{ + exception: exception{ + extDabtPending: 1, + }, } ) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index b35c930e2..4d912769a 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -80,18 +80,67 @@ func getHypercallID(addr uintptr) int { // //go:nosplit func bluepillStopGuest(c *vCPU) { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 { - throw("sErr injection failed") + uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 { + throw("bounce sErr injection failed") } } +// bluepillSigBus is reponsible for injecting sError to trigger sigbus. +// +//go:nosplit +func bluepillSigBus(c *vCPU) { + // Host must support ARM64_HAS_RAS_EXTN. + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_VCPU_EVENTS, + uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 { + if errno == syscall.EINVAL { + throw("No ARM64_HAS_RAS_EXTN feature in host.") + } + throw("nmi sErr injection failed") + } +} + +// bluepillExtDabt is reponsible for injecting external data abort. +// +//go:nosplit +func bluepillExtDabt(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_VCPU_EVENTS, + uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 { + throw("ext_dabt injection failed") + } +} + +// bluepillHandleEnosys is reponsible for handling enosys error. +// +//go:nosplit +func bluepillHandleEnosys(c *vCPU) { + bluepillExtDabt(c) +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection. // //go:nosplit func bluepillReadyStopGuest(c *vCPU) bool { return true } + +// bluepillArchHandleExit checks architecture specific exitcode. +// +//go:nosplit +func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) { + switch c.runData.exitReason { + case _KVM_EXIT_ARM_NISV: + bluepillExtDabt(c) + default: + c.die(bluepillArchContext(context), "unknown") + } +} diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index eb05950cd..8c5369377 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -146,13 +146,11 @@ func bluepillHandler(context unsafe.Pointer) { // MMIO exit we receive EFAULT from the run ioctl. We // always inject an NMI here since we may be in kernel // mode and have interrupts disabled. - if _, _, errno := syscall.RawSyscall( // escapes: no. - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_NMI, 0); errno != 0 { - throw("NMI injection failed") - } + bluepillSigBus(c) continue // Rerun vCPU. + case syscall.ENOSYS: + bluepillHandleEnosys(c) + continue default: throw("run failed") } @@ -225,7 +223,7 @@ func bluepillHandler(context unsafe.Pointer) { c.die(bluepillArchContext(context), "entry failed") return default: - c.die(bluepillArchContext(context), "unknown") + bluepillArchHandleExit(c, context) return } } diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index dd45ad10b..5979aef97 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -158,8 +158,7 @@ func (*KVM) MaxUserAddress() usermem.Addr { // NewAddressSpace returns a new pagetable root. func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { // Allocate page tables and install system mappings. - pageTables := pagetables.New(newAllocator()) - k.machine.mapUpperHalf(pageTables) + pageTables := pagetables.NewWithUpper(newAllocator(), k.machine.upperSharedPageTables, ring0.KernelStartAddress) // Return the new address space. return &addressSpace{ diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 0b06a923a..9db1db4e9 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -47,10 +47,11 @@ type userRegs struct { } type exception struct { - sErrPending uint8 - sErrHasEsr uint8 - pad [6]uint8 - sErrEsr uint64 + sErrPending uint8 + sErrHasEsr uint8 + extDabtPending uint8 + pad [5]uint8 + sErrEsr uint64 } type kvmVcpuEvents struct { diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 6abaa21c4..2492d57be 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -56,6 +56,7 @@ const ( _KVM_EXIT_FAIL_ENTRY = 0x9 _KVM_EXIT_INTERNAL_ERROR = 0x11 _KVM_EXIT_SYSTEM_EVENT = 0x18 + _KVM_EXIT_ARM_NISV = 0x1c ) // KVM capability options. diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index 84df0f878..b060d9544 100644 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -38,6 +38,8 @@ const ( _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080 _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082 _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600 + _KVM_ARM64_REGS_TIMER_CNT = 0x603000000013df1a + _KVM_ARM64_REGS_CNTFRQ_EL0 = 0x603000000013df00 ) // Arm64: Architectural Feature Access Control Register EL1. @@ -149,6 +151,9 @@ const ( _ESR_SEGV_PEMERR_L1 = 0xd _ESR_SEGV_PEMERR_L2 = 0xe _ESR_SEGV_PEMERR_L3 = 0xf + + // Custom ISS field definitions for system error. + _ESR_ELx_SERR_NMI = 0x1 ) // Arm64: MMIO base address used to dispatch hypercalls. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 61ed24d01..e2fffc99b 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -40,6 +41,9 @@ type machine struct { // slots are currently being updated, and the caller should retry. nextSlot uint32 + // upperSharedPageTables tracks the read-only shared upper of all the pagetables. + upperSharedPageTables *pagetables.PageTables + // kernel is the set of global structures. kernel ring0.Kernel @@ -198,9 +202,7 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) - m.kernel.Init(ring0.KernelOpts{ - PageTables: pagetables.New(newAllocator()), - }, m.maxVCPUs) + m.kernel.Init(m.maxVCPUs) // Pull the maximum slots. maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) @@ -212,6 +214,13 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of slots is %d.", m.maxSlots) m.usedSlots = make([]uintptr, m.maxSlots) + // Create the upper shared pagetables and kernel(sentry) pagetables. + m.upperSharedPageTables = pagetables.New(newAllocator()) + m.mapUpperHalf(m.upperSharedPageTables) + m.upperSharedPageTables.Allocator.(*allocator).base.Drain() + m.upperSharedPageTables.MarkReadOnlyShared() + m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -225,7 +234,6 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) - m.mapUpperHalf(m.kernel.PageTables) var physicalRegionsReadOnly []physicalRegion var physicalRegionsAvailable []physicalRegion @@ -625,3 +633,35 @@ func (c *vCPU) BounceToKernel() { func (c *vCPU) BounceToHost() { c.bounce(true) } + +// setSystemTimeLegacy calibrates and sets an approximate system time. +func (c *vCPU) setSystemTimeLegacy() error { + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Try to set the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + start := uint64(ktime.Rdtsc()) + if err := c.setTSC(start + (minimum / 2)); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? + upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && current <= upperThreshold { + return nil + } + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c67127d95..8e03c310d 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -252,38 +252,6 @@ func (c *vCPU) setSystemTime() error { } } -// setSystemTimeLegacy calibrates and sets an approximate system time. -func (c *vCPU) setSystemTimeLegacy() error { - const minIterations = 10 - minimum := uint64(0) - for iter := 0; ; iter++ { - // Try to set the TSC to an estimate of where it will be - // on the host during a "fast" system call iteration. - start := uint64(ktime.Rdtsc()) - if err := c.setTSC(start + (minimum / 2)); err != nil { - return err - } - // See if this is our new minimum call time. Note that this - // serves two functions: one, we make sure that we are - // accurately predicting the offset we need to set. Second, we - // don't want to do the final set on a slow call, which could - // produce a really bad result. - end := uint64(ktime.Rdtsc()) - if end < start { - continue // Totally bogus: unstable TSC? - } - current := end - start - if current < minimum || iter == 0 { - minimum = current // Set our new minimum. - } - // Is this past minIterations and within ~10% of minimum? - upperThreshold := (((minimum << 3) + minimum) >> 3) - if iter >= minIterations && current <= upperThreshold { - return nil - } - } -} - // nonCanonical generates a canonical address return. // //go:nosplit @@ -464,30 +432,27 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { return physicalRegions } -var execRegions = func() (regions []region) { +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + // Map all the executible regions so that all the entry functions + // are mapped in the upper half. applyVirtualRegions(func(vr virtualRegion) { if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { return } + if vr.accessType.Execute { - regions = append(regions, vr.region) + r := vr.region + physical, length, ok := translateToPhysical(r.virtual) + if !ok || length < r.length { + panic("impossible translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|r.virtual), + r.length, + pagetables.MapOpts{AccessType: usermem.Execute}, + physical) } }) - return -}() - -func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { - for _, r := range execRegions { - physical, length, ok := translateToPhysical(r.virtual) - if !ok || length < r.length { - panic("impossilbe translation") - } - pageTable.Map( - usermem.Addr(ring0.KernelStartAddress|r.virtual), - r.length, - pagetables.MapOpts{AccessType: usermem.Execute}, - physical) - } for start, end := range m.kernel.EntryRegions() { regionLen := end - start physical, length, ok := translateToPhysical(start) diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index a163f956d..3f5be276b 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -159,9 +159,33 @@ func (c *vCPU) initArchState() error { } c.floatingPointState = arch.NewFloatingPointData() + + return c.setSystemTime() +} + +// setTSC sets the counter Virtual Offset. +func (c *vCPU) setTSC(value uint64) error { + var ( + reg kvmOneReg + data uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + reg.id = _KVM_ARM64_REGS_TIMER_CNT + data = uint64(value) + + if err := c.setOneRegister(®); err != nil { + return err + } + return nil } +// setSystemTime sets the vCPU to the system time. +func (c *vCPU) setSystemTime() error { + return c.setSystemTimeLegacy() +} + //go:nosplit func (c *vCPU) loadSegments(tid uint64) { // TODO(gvisor.dev/issue/1238): TLS is not supported. @@ -197,7 +221,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info) } else if !ring0.IsCanonical(regs.Sp) { - return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) + return nonCanonical(regs.Sp, int32(syscall.SIGSEGV), info) } // Assign PCIDs. @@ -233,16 +257,12 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.PageFault: return c.fault(int32(syscall.SIGSEGV), info) + case ring0.El0ErrNMI: + return c.fault(int32(syscall.SIGBUS), info) case ring0.Vector(bounce): // ring0.VirtualizationException return usermem.NoAccess, platform.ErrContextInterrupt - case ring0.El0Sync_undef, - ring0.El1Sync_undef: - *info = arch.SignalInfo{ - Signo: int32(syscall.SIGILL), - Code: 1, // ILL_ILLOPC (illegal opcode). - } - info.SetAddr(switchOpts.Registers.Pc) // Include address. - return usermem.AccessType{}, platform.ErrContextSignal + case ring0.El0SyncUndef: + return c.fault(int32(syscall.SIGILL), info) default: panic(fmt.Sprintf("unexpected vector: 0x%x", vector)) } diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index f56aa3b79..571bfcc2e 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -18,8 +18,8 @@ // // In a nutshell, it works as follows: // -// The creation of a new address space creates a new child processes with a -// single thread which is traced by a single goroutine. +// The creation of a new address space creates a new child process with a single +// thread which is traced by a single goroutine. // // A context is just a collection of temporary variables. Calling Switch on a // context does the following: diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 812ab80ef..aacd7ce70 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // facilitate vsyscall emulation. See patchSignalInfo. patchSignalInfo(regs, &c.signalInfo) return false - } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) { + } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) { // The signal was generated by this process. That means // that it was an interrupt or something else that we // should bail for. Note that we ignore signals diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 679b287c3..2852b7387 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "arch_genrule", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -39,19 +39,19 @@ go_template_instance( template = ":defs_arm64", ) -genrule( +arch_genrule( name = "entry_impl_amd64", srcs = ["entry_amd64.s"], outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) -genrule( +arch_genrule( name = "entry_impl_arm64", srcs = ["entry_arm64.s"], outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) @@ -72,7 +72,6 @@ go_library( "lib_amd64.s", "lib_arm64.go", "lib_arm64.s", - "lib_arm64_unsafe.go", "ring0.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index 87a573cc4..c51df2811 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -58,46 +58,56 @@ type Vector uintptr // Exception vectors. const ( - El1SyncInvalid = iota - El1IrqInvalid - El1FiqInvalid - El1ErrorInvalid + El1InvSync = iota + El1InvIrq + El1InvFiq + El1InvError + El1Sync El1Irq El1Fiq - El1Error + El1Err + El0Sync El0Irq El0Fiq - El0Error - El0Sync_invalid - El0Irq_invalid - El0Fiq_invalid - El0Error_invalid - El1Sync_da - El1Sync_ia - El1Sync_sp_pc - El1Sync_undef - El1Sync_dbg - El1Sync_inv - El0Sync_svc - El0Sync_da - El0Sync_ia - El0Sync_fpsimd_acc - El0Sync_sve_acc - El0Sync_sys - El0Sync_sp_pc - El0Sync_undef - El0Sync_dbg - El0Sync_inv + El0Err + + El0InvSync + El0InvIrq + El0InvFiq + El0InvErr + + El1SyncDa + El1SyncIa + El1SyncSpPc + El1SyncUndef + El1SyncDbg + El1SyncInv + + El0SyncSVC + El0SyncDa + El0SyncIa + El0SyncFpsimdAcc + El0SyncSveAcc + El0SyncFpsimdExc + El0SyncSys + El0SyncSpPc + El0SyncUndef + El0SyncDbg + El0SyncInv + + El0ErrNMI + El0ErrBounce + _NR_INTERRUPTS ) // System call vectors. const ( - Syscall Vector = El0Sync_svc - PageFault Vector = El0Sync_da - VirtualizationException Vector = El0Error + Syscall Vector = El0SyncSVC + PageFault Vector = El0SyncDa + VirtualizationException Vector = El0ErrBounce ) // VirtualAddressBits returns the number bits available for virtual addresses. diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index e6daf24df..f9765771e 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -23,6 +23,9 @@ import ( // // This contains global state, shared by multiple CPUs. type Kernel struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables + KernelArchState } diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 00899273e..7a2275558 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -66,17 +66,9 @@ var ( KernelDataSegment SegmentDescriptor ) -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - // KernelArchState contains architecture-specific state. type KernelArchState struct { - KernelOpts - - // cpuEntries is array of kernelEntry for all cpus + // cpuEntries is array of kernelEntry for all cpus. cpuEntries []kernelEntry // globalIDT is our set of interrupt gates. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index 508236e46..a014dcbc0 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -32,15 +32,8 @@ var ( KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) ) -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - // KernelArchState contains architecture-specific state. type KernelArchState struct { - KernelOpts } // CPUArchState contains CPU-specific arch state. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2370a9276..cf0bf3528 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -132,40 +132,6 @@ MOVD offset+PTRACE_R29(reg), R29; \ MOVD offset+PTRACE_R30(reg), R30; -// NOP-s -#define nop31Instructions() \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; - #define ESR_ELx_EC_UNKNOWN (0x00) #define ESR_ELx_EC_WFx (0x01) /* Unallocated EC: 0x02 */ @@ -288,6 +254,10 @@ #define ESR_ELx_WFx_ISS_WFE (UL(1) << 0) #define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1) +/* ISS field definitions for system error */ +#define ESR_ELx_SERR_MASK (0x1) +#define ESR_ELx_SERR_NMI (0x1) + // LOAD_KERNEL_ADDRESS loads a kernel address. #define LOAD_KERNEL_ADDRESS(from, to) \ MOVD from, to; \ @@ -320,6 +290,18 @@ MOVD CPU_TTBR0_KVM(from), RSV_REG; \ MSR RSV_REG, TTBR0_EL1; +TEXT ·EnableVFP(SB),NOSPLIT,$0 + MOVD $FPEN_ENABLE, R0 + WORD $0xd5181040 //MSR R0, CPACR_EL1 + ISB $15 + RET + +TEXT ·DisableVFP(SB),NOSPLIT,$0 + MOVD $0, R0 + WORD $0xd5181040 //MSR R0, CPACR_EL1 + ISB $15 + RET + #define VFP_ENABLE \ MOVD $FPEN_ENABLE, R0; \ WORD $0xd5181040; \ //MSR R0, CPACR_EL1 @@ -366,6 +348,25 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// EXCEPTION_EL0 is a common el0 exception handler function. +#define EXCEPTION_EL0(vector) \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + WORD $0xd538601a; \ //MRS FAR_EL1, R26 + MOVD R26, CPU_FAULT_ADDR(RSV_REG); \ + MOVD $1, R3; \ + MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user. + MOVD $vector, R3; \ + MOVD R3, CPU_VECTOR_CODE(RSV_REG); \ + MRS ESR_EL1, R3; \ + MOVD R3, CPU_ERROR_CODE(RSV_REG); \ + B ·kernelExitToEl1(SB); + +// EXCEPTION_EL1 is a common el1 exception handler function. +#define EXCEPTION_EL1(vector) \ + MOVD $vector, R3; \ + MOVD R3, 8(RSP); \ + B ·HaltEl1ExceptionAndResume(SB); + // storeAppASID writes the application's asid value. TEXT ·storeAppASID(SB),NOSPLIT,$0-8 MOVD asid+0(FP), R1 @@ -413,6 +414,16 @@ TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0 CALL ·kernelSyscall(SB) // Call the trampoline. B ·kernelExitToEl1(SB) // Resume. +// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume. +TEXT ·HaltEl1ExceptionAndResume(SB),NOSPLIT,$0-8 + WORD $0xd538d092 // MRS TPIDR_EL1, R18 + MOVD CPU_SELF(RSV_REG), R3 // Load vCPU. + MOVD R3, 8(RSP) // First argument (vCPU). + MOVD vector+0(FP), R3 + MOVD R3, 16(RSP) // Second argument (vector). + CALL ·kernelException(SB) // Call the trampoline. + B ·kernelExitToEl1(SB) // Resume. + // Shutdown stops the guest. TEXT ·Shutdown(SB),NOSPLIT,$0 // PSCI EVENT. @@ -503,6 +514,10 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1 MSR R1, ELR_EL1 + // restore sentry's tls. + MOVD CPU_REGISTERS+PTRACE_TLS(RSV_REG), R1 + MSR R1, TPIDR_EL0 + MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 MOVD R1, RSP @@ -571,39 +586,22 @@ TEXT ·El1_sync(SB),NOSPLIT,$0 B el1_invalid el1_da: + EXCEPTION_EL1(El1SyncDa) el1_ia: - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - WORD $0xd538601a //MRS FAR_EL1, R26 - - MOVD R26, CPU_FAULT_ADDR(RSV_REG) - - MOVD $0, CPU_ERROR_TYPE(RSV_REG) - - MOVD $PageFault, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - B ·HaltAndResume(SB) - + EXCEPTION_EL1(El1SyncIa) el1_sp_pc: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncSpPc) el1_undef: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncUndef) el1_svc: - MOVD $0, CPU_ERROR_CODE(RSV_REG) - MOVD $0, CPU_ERROR_TYPE(RSV_REG) B ·HaltEl1SvcAndResume(SB) - el1_dbg: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncDbg) el1_fpsimd_acc: VFP_ENABLE B ·kernelExitToEl1(SB) // Resume. - el1_invalid: - B ·Shutdown(SB) + EXCEPTION_EL1(El1SyncInv) // El1_irq is the handler for El1_irq. TEXT ·El1_irq(SB),NOSPLIT,$0 @@ -659,45 +657,21 @@ el0_svc: el0_da: el0_ia: - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - WORD $0xd538601a //MRS FAR_EL1, R26 - - MOVD R26, CPU_FAULT_ADDR(RSV_REG) - - MOVD $1, R3 - MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. - - MOVD $PageFault, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - MRS ESR_EL1, R3 - MOVD R3, CPU_ERROR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) - + EXCEPTION_EL0(PageFault) el0_fpsimd_acc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncFpsimdAcc) el0_sve_acc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncSveAcc) el0_fpsimd_exc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncFpsimdExc) el0_sp_pc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncSpPc) el0_undef: - MOVD $El0Sync_undef, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) - + EXCEPTION_EL0(El0SyncUndef) el0_dbg: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncDbg) el0_invalid: - B ·Shutdown(SB) + EXCEPTION_EL0(El0SyncInv) TEXT ·El0_irq(SB),NOSPLIT,$0 B ·Shutdown(SB) @@ -707,6 +681,29 @@ TEXT ·El0_fiq(SB),NOSPLIT,$0 TEXT ·El0_error(SB),NOSPLIT,$0 KERNEL_ENTRY_FROM_EL0 + WORD $0xd5385219 // MRS ESR_EL1, R25 + AND $ESR_ELx_SERR_MASK, R25, R24 + CMP $ESR_ELx_SERR_NMI, R24 + BEQ el0_nmi + B el0_bounce +el0_nmi: + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + WORD $0xd538601a //MRS FAR_EL1, R26 + + MOVD R26, CPU_FAULT_ADDR(RSV_REG) + + MOVD $1, R3 + MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. + + MOVD $El0ErrNMI, R3 + MOVD R3, CPU_VECTOR_CODE(RSV_REG) + + MRS ESR_EL1, R3 + MOVD R3, CPU_ERROR_CODE(RSV_REG) + + B ·kernelExitToEl1(SB) + +el0_bounce: WORD $0xd538d092 //MRS TPIDR_EL1, R18 WORD $0xd538601a //MRS FAR_EL1, R26 @@ -718,7 +715,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0 MOVD $VirtualizationException, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·HaltAndResume(SB) + B ·kernelExitToEl1(SB) TEXT ·El0_sync_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) @@ -733,79 +730,43 @@ TEXT ·El0_error_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) // Vectors implements exception vector table. +// The start address of exception vector table should be 11-bits aligned. +// For detail, please refer to arm developer document: +// https://developer.arm.com/documentation/100933/0100/AArch64-exception-vector-table +// Also can refer to the code in linux kernel: arch/arm64/kernel/entry.S TEXT ·Vectors(SB),NOSPLIT,$0 + PCALIGN $2048 B ·El1_sync_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_irq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_fiq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_error_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_sync(SB) - nop31Instructions() + PCALIGN $128 B ·El1_irq(SB) - nop31Instructions() + PCALIGN $128 B ·El1_fiq(SB) - nop31Instructions() + PCALIGN $128 B ·El1_error(SB) - nop31Instructions() + PCALIGN $128 B ·El0_sync(SB) - nop31Instructions() + PCALIGN $128 B ·El0_irq(SB) - nop31Instructions() + PCALIGN $128 B ·El0_fiq(SB) - nop31Instructions() + PCALIGN $128 B ·El0_error(SB) - nop31Instructions() + PCALIGN $128 B ·El0_sync_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_irq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_fiq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_error_invalid(SB) - nop31Instructions() - - // The exception-vector-table is required to be 11-bits aligned. - // Please see Linux source code as reference: arch/arm64/kernel/entry.s. - // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs. - // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. - WORD $0xd503201f //nop - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 9742308d8..a9703baf6 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,6 +24,9 @@ go_binary( "defs_impl_arm64.go", "main.go", ], + # Use the libc malloc to avoid any extra dependencies. This is required to + # pass the sentry deps test. + system_malloc = True, visibility = [ "//pkg/sentry/platform/kvm:__pkg__", "//pkg/sentry/platform/ring0:__pkg__", diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 264be23d3..292f9d0cc 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -16,11 +16,9 @@ package ring0 // Init initializes a new kernel. // -// N.B. that constraints on KernelOpts must be satisfied. -// //go:nosplit -func (k *Kernel) Init(opts KernelOpts, maxCPUs int) { - k.init(opts, maxCPUs) +func (k *Kernel) Init(maxCPUs int) { + k.init(maxCPUs) } // Halt halts execution. diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index 3a9dff4cc..b55dc29b3 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -24,10 +24,7 @@ import ( ) // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts, maxCPUs int) { - // Save the root page tables. - k.PageTables = opts.PageTables - +func (k *Kernel) init(maxCPUs int) { entrySize := reflect.TypeOf(kernelEntry{}).Size() var ( entries []kernelEntry diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 68291b504..bffe27e5c 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -24,10 +24,12 @@ func HaltAndResume() //go:nosplit func HaltEl1SvcAndResume() +// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume. +//go:nosplit +func HaltEl1ExceptionAndResume() + // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts, maxCPUs int) { - // Save the root page tables. - k.PageTables = opts.PageTables +func (k *Kernel) init(maxCPUs int) { } // init initializes architecture-specific state. @@ -69,11 +71,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Pstate &= ^uint64(PsrFlagsClear) regs.Pstate |= UserFlagsSet + EnableVFP() LoadFloatingPoint(switchOpts.FloatingPointState) kernelExitToEl0() SaveFloatingPoint(switchOpts.FloatingPointState) + DisableVFP() vector = c.vecCode diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 456107cd8..29b3efd34 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -53,9 +53,13 @@ func LoadFloatingPoint(*byte) // SaveFloatingPoint saves floating point state. func SaveFloatingPoint(*byte) +// EnableVFP enables fpsimd. +func EnableVFP() + +// DisableVFP disables fpsimd. +func DisableVFP() + // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. -func Init() { - rewriteVectors() -} +func Init() {} diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index da9d3cf55..6f4923539 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -35,62 +35,47 @@ TEXT ·CPACREL1(SB),NOSPLIT,$0-8 RET TEXT ·GetFPCR(SB),NOSPLIT,$0-8 - WORD $0xd53b4201 // MRS NZCV, R1 + MOVD FPCR, R1 MOVD R1, ret+0(FP) RET TEXT ·GetFPSR(SB),NOSPLIT,$0-8 - WORD $0xd53b4421 // MRS FPSR, R1 + MOVD FPSR, R1 MOVD R1, ret+0(FP) RET TEXT ·SetFPCR(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R1 - WORD $0xd51b4201 // MSR R1, NZCV + MOVD R1, FPCR RET TEXT ·SetFPSR(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R1 - WORD $0xd51b4421 // MSR R1, FPSR + MOVD R1, FPSR RET TEXT ·SaveVRegs(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R0 // Skip aarch64_ctx, fpsr, fpcr. - FMOVD F0, 16*1(R0) - FMOVD F1, 16*2(R0) - FMOVD F2, 16*3(R0) - FMOVD F3, 16*4(R0) - FMOVD F4, 16*5(R0) - FMOVD F5, 16*6(R0) - FMOVD F6, 16*7(R0) - FMOVD F7, 16*8(R0) - FMOVD F8, 16*9(R0) - FMOVD F9, 16*10(R0) - FMOVD F10, 16*11(R0) - FMOVD F11, 16*12(R0) - FMOVD F12, 16*13(R0) - FMOVD F13, 16*14(R0) - FMOVD F14, 16*15(R0) - FMOVD F15, 16*16(R0) - FMOVD F16, 16*17(R0) - FMOVD F17, 16*18(R0) - FMOVD F18, 16*19(R0) - FMOVD F19, 16*20(R0) - FMOVD F20, 16*21(R0) - FMOVD F21, 16*22(R0) - FMOVD F22, 16*23(R0) - FMOVD F23, 16*24(R0) - FMOVD F24, 16*25(R0) - FMOVD F25, 16*26(R0) - FMOVD F26, 16*27(R0) - FMOVD F27, 16*28(R0) - FMOVD F28, 16*29(R0) - FMOVD F29, 16*30(R0) - FMOVD F30, 16*31(R0) - FMOVD F31, 16*32(R0) - ISB $15 + ADD $16, R0, R0 + + WORD $0xad000400 // stp q0, q1, [x0] + WORD $0xad010c02 // stp q2, q3, [x0, #32] + WORD $0xad021404 // stp q4, q5, [x0, #64] + WORD $0xad031c06 // stp q6, q7, [x0, #96] + WORD $0xad042408 // stp q8, q9, [x0, #128] + WORD $0xad052c0a // stp q10, q11, [x0, #160] + WORD $0xad06340c // stp q12, q13, [x0, #192] + WORD $0xad073c0e // stp q14, q15, [x0, #224] + WORD $0xad084410 // stp q16, q17, [x0, #256] + WORD $0xad094c12 // stp q18, q19, [x0, #288] + WORD $0xad0a5414 // stp q20, q21, [x0, #320] + WORD $0xad0b5c16 // stp q22, q23, [x0, #352] + WORD $0xad0c6418 // stp q24, q25, [x0, #384] + WORD $0xad0d6c1a // stp q26, q27, [x0, #416] + WORD $0xad0e741c // stp q28, q29, [x0, #448] + WORD $0xad0f7c1e // stp q30, q31, [x0, #480] RET @@ -98,39 +83,24 @@ TEXT ·LoadVRegs(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R0 // Skip aarch64_ctx, fpsr, fpcr. - FMOVD 16*1(R0), F0 - FMOVD 16*2(R0), F1 - FMOVD 16*3(R0), F2 - FMOVD 16*4(R0), F3 - FMOVD 16*5(R0), F4 - FMOVD 16*6(R0), F5 - FMOVD 16*7(R0), F6 - FMOVD 16*8(R0), F7 - FMOVD 16*9(R0), F8 - FMOVD 16*10(R0), F9 - FMOVD 16*11(R0), F10 - FMOVD 16*12(R0), F11 - FMOVD 16*13(R0), F12 - FMOVD 16*14(R0), F13 - FMOVD 16*15(R0), F14 - FMOVD 16*16(R0), F15 - FMOVD 16*17(R0), F16 - FMOVD 16*18(R0), F17 - FMOVD 16*19(R0), F18 - FMOVD 16*20(R0), F19 - FMOVD 16*21(R0), F20 - FMOVD 16*22(R0), F21 - FMOVD 16*23(R0), F22 - FMOVD 16*24(R0), F23 - FMOVD 16*25(R0), F24 - FMOVD 16*26(R0), F25 - FMOVD 16*27(R0), F26 - FMOVD 16*28(R0), F27 - FMOVD 16*29(R0), F28 - FMOVD 16*30(R0), F29 - FMOVD 16*31(R0), F30 - FMOVD 16*32(R0), F31 - ISB $15 + ADD $16, R0, R0 + + WORD $0xad400400 // ldp q0, q1, [x0] + WORD $0xad410c02 // ldp q2, q3, [x0, #32] + WORD $0xad421404 // ldp q4, q5, [x0, #64] + WORD $0xad431c06 // ldp q6, q7, [x0, #96] + WORD $0xad442408 // ldp q8, q9, [x0, #128] + WORD $0xad452c0a // ldp q10, q11, [x0, #160] + WORD $0xad46340c // ldp q12, q13, [x0, #192] + WORD $0xad473c0e // ldp q14, q15, [x0, #224] + WORD $0xad484410 // ldp q16, q17, [x0, #256] + WORD $0xad494c12 // ldp q18, q19, [x0, #288] + WORD $0xad4a5414 // ldp q20, q21, [x0, #320] + WORD $0xad4b5c16 // ldp q22, q23, [x0, #352] + WORD $0xad4c6418 // ldp q24, q25, [x0, #384] + WORD $0xad4d6c1a // ldp q26, q27, [x0, #416] + WORD $0xad4e741c // ldp q28, q29, [x0, #448] + WORD $0xad4f7c1e // ldp q30, q31, [x0, #480] RET @@ -140,40 +110,26 @@ TEXT ·LoadFloatingPoint(SB),NOSPLIT,$0-8 MOVD 0(R0), R1 MOVD R1, FPSR MOVD 8(R0), R1 - MOVD R1, NZCV - - FMOVD 16*1(R0), F0 - FMOVD 16*2(R0), F1 - FMOVD 16*3(R0), F2 - FMOVD 16*4(R0), F3 - FMOVD 16*5(R0), F4 - FMOVD 16*6(R0), F5 - FMOVD 16*7(R0), F6 - FMOVD 16*8(R0), F7 - FMOVD 16*9(R0), F8 - FMOVD 16*10(R0), F9 - FMOVD 16*11(R0), F10 - FMOVD 16*12(R0), F11 - FMOVD 16*13(R0), F12 - FMOVD 16*14(R0), F13 - FMOVD 16*15(R0), F14 - FMOVD 16*16(R0), F15 - FMOVD 16*17(R0), F16 - FMOVD 16*18(R0), F17 - FMOVD 16*19(R0), F18 - FMOVD 16*20(R0), F19 - FMOVD 16*21(R0), F20 - FMOVD 16*22(R0), F21 - FMOVD 16*23(R0), F22 - FMOVD 16*24(R0), F23 - FMOVD 16*25(R0), F24 - FMOVD 16*26(R0), F25 - FMOVD 16*27(R0), F26 - FMOVD 16*28(R0), F27 - FMOVD 16*29(R0), F28 - FMOVD 16*30(R0), F29 - FMOVD 16*31(R0), F30 - FMOVD 16*32(R0), F31 + MOVD R1, FPCR + + ADD $16, R0, R0 + + WORD $0xad400400 // ldp q0, q1, [x0] + WORD $0xad410c02 // ldp q2, q3, [x0, #32] + WORD $0xad421404 // ldp q4, q5, [x0, #64] + WORD $0xad431c06 // ldp q6, q7, [x0, #96] + WORD $0xad442408 // ldp q8, q9, [x0, #128] + WORD $0xad452c0a // ldp q10, q11, [x0, #160] + WORD $0xad46340c // ldp q12, q13, [x0, #192] + WORD $0xad473c0e // ldp q14, q15, [x0, #224] + WORD $0xad484410 // ldp q16, q17, [x0, #256] + WORD $0xad494c12 // ldp q18, q19, [x0, #288] + WORD $0xad4a5414 // ldp q20, q21, [x0, #320] + WORD $0xad4b5c16 // ldp q22, q23, [x0, #352] + WORD $0xad4c6418 // ldp q24, q25, [x0, #384] + WORD $0xad4d6c1a // ldp q26, q27, [x0, #416] + WORD $0xad4e741c // ldp q28, q29, [x0, #448] + WORD $0xad4f7c1e // ldp q30, q31, [x0, #480] RET @@ -182,40 +138,26 @@ TEXT ·SaveFloatingPoint(SB),NOSPLIT,$0-8 MOVD FPSR, R1 MOVD R1, 0(R0) - MOVD NZCV, R1 + MOVD FPCR, R1 MOVD R1, 8(R0) - FMOVD F0, 16*1(R0) - FMOVD F1, 16*2(R0) - FMOVD F2, 16*3(R0) - FMOVD F3, 16*4(R0) - FMOVD F4, 16*5(R0) - FMOVD F5, 16*6(R0) - FMOVD F6, 16*7(R0) - FMOVD F7, 16*8(R0) - FMOVD F8, 16*9(R0) - FMOVD F9, 16*10(R0) - FMOVD F10, 16*11(R0) - FMOVD F11, 16*12(R0) - FMOVD F12, 16*13(R0) - FMOVD F13, 16*14(R0) - FMOVD F14, 16*15(R0) - FMOVD F15, 16*16(R0) - FMOVD F16, 16*17(R0) - FMOVD F17, 16*18(R0) - FMOVD F18, 16*19(R0) - FMOVD F19, 16*20(R0) - FMOVD F20, 16*21(R0) - FMOVD F21, 16*22(R0) - FMOVD F22, 16*23(R0) - FMOVD F23, 16*24(R0) - FMOVD F24, 16*25(R0) - FMOVD F25, 16*26(R0) - FMOVD F26, 16*27(R0) - FMOVD F27, 16*28(R0) - FMOVD F28, 16*29(R0) - FMOVD F29, 16*30(R0) - FMOVD F30, 16*31(R0) - FMOVD F31, 16*32(R0) + ADD $16, R0, R0 + + WORD $0xad000400 // stp q0, q1, [x0] + WORD $0xad010c02 // stp q2, q3, [x0, #32] + WORD $0xad021404 // stp q4, q5, [x0, #64] + WORD $0xad031c06 // stp q6, q7, [x0, #96] + WORD $0xad042408 // stp q8, q9, [x0, #128] + WORD $0xad052c0a // stp q10, q11, [x0, #160] + WORD $0xad06340c // stp q12, q13, [x0, #192] + WORD $0xad073c0e // stp q14, q15, [x0, #224] + WORD $0xad084410 // stp q16, q17, [x0, #256] + WORD $0xad094c12 // stp q18, q19, [x0, #288] + WORD $0xad0a5414 // stp q20, q21, [x0, #320] + WORD $0xad0b5c16 // stp q22, q23, [x0, #352] + WORD $0xad0c6418 // stp q24, q25, [x0, #384] + WORD $0xad0d6c1a // stp q26, q27, [x0, #416] + WORD $0xad0e741c // stp q28, q29, [x0, #448] + WORD $0xad0f7c1e // stp q30, q31, [x0, #480] RET diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go deleted file mode 100644 index c05166fea..000000000 --- a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package ring0 - -import ( - "reflect" - "syscall" - "unsafe" - - "gvisor.dev/gvisor/pkg/safecopy" - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - nopInstruction = 0xd503201f - instSize = unsafe.Sizeof(uint32(0)) - vectorsRawLen = 0x800 -) - -func unsafeSlice(addr uintptr, length int) (slice []uint32) { - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) - hdr.Data = addr - hdr.Len = length / int(instSize) - hdr.Cap = length / int(instSize) - return slice -} - -// Work around: move ring0.Vectors() into a specific address with 11-bits alignment. -// -// According to the design documentation of Arm64, -// the start address of exception vector table should be 11-bits aligned. -// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S -// But, we can't align a function's start address to a specific address by using golang. -// We have raised this question in golang community: -// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I -// This function will be removed when golang supports this feature. -// -// There are 2 jobs were implemented in this function: -// 1, move the start address of exception vector table into the specific address. -// 2, modify the offset of each instruction. -func rewriteVectors() { - vectorsBegin := reflect.ValueOf(Vectors).Pointer() - - // The exception-vector-table is required to be 11-bits aligned. - // And the size is 0x800. - // Please see the documentation as reference: - // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table - // - // But, golang does not allow to set a function's address to a specific value. - // So, for gvisor, I defined the size of exception-vector-table as 4K, - // filled the 2nd 2K part with NOP-s. - // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. - // - // So, the prerequisite for this function to work correctly is: - // vectorsSafeLen >= 0x1000 - // vectorsRawLen = 0x800 - vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin) - if vectorsSafeLen < 2*vectorsRawLen { - panic("Can't update vectors") - } - - vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32 - vectorsRawLen32 := vectorsRawLen / int(instSize) - - offset := vectorsBegin & (1<<11 - 1) - if offset != 0 { - offset = 1<<11 - offset - } - - pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1) - - _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)) - if errno != 0 { - panic(errno.Error()) - } - - offset = offset / instSize // By index, not bytes. - // Move exception-vector-table into the specific address, should uses memmove here. - for i := 1; i <= vectorsRawLen32; i++ { - vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i] - } - - // Adjust branch since instruction was moved forward. - for i := 0; i < vectorsRawLen32; i++ { - if vectorsSafeTable[int(offset)+i] != nopInstruction { - vectorsSafeTable[int(offset)+i] -= uint32(offset) - } - } - - _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC)) - if errno != 0 { - panic(errno.Error()) - } -} diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index 45eba960d..b5652deb9 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -47,43 +47,37 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) fmt.Fprintf(w, "\n// Vectors.\n") - fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid) - fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid) - fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid) - fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid) fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync) fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq) fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq) - fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error) + fmt.Fprintf(w, "#define El1Err 0x%02x\n", El1Err) fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync) fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq) fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq) - fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error) + fmt.Fprintf(w, "#define El0Err 0x%02x\n", El0Err) - fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid) - fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid) - fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid) - fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid) + fmt.Fprintf(w, "#define El1SyncDa 0x%02x\n", El1SyncDa) + fmt.Fprintf(w, "#define El1SyncIa 0x%02x\n", El1SyncIa) + fmt.Fprintf(w, "#define El1SyncSpPc 0x%02x\n", El1SyncSpPc) + fmt.Fprintf(w, "#define El1SyncUndef 0x%02x\n", El1SyncUndef) + fmt.Fprintf(w, "#define El1SyncDbg 0x%02x\n", El1SyncDbg) + fmt.Fprintf(w, "#define El1SyncInv 0x%02x\n", El1SyncInv) - fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da) - fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia) - fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc) - fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef) - fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg) - fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv) + fmt.Fprintf(w, "#define El0SyncSVC 0x%02x\n", El0SyncSVC) + fmt.Fprintf(w, "#define El0SyncDa 0x%02x\n", El0SyncDa) + fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa) + fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc) + fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc) + fmt.Fprintf(w, "#define El0SyncFpsimdExc 0x%02x\n", El0SyncFpsimdExc) + fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys) + fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc) + fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef) + fmt.Fprintf(w, "#define El0SyncDbg 0x%02x\n", El0SyncDbg) + fmt.Fprintf(w, "#define El0SyncInv 0x%02x\n", El0SyncInv) - fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc) - fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da) - fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia) - fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc) - fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc) - fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys) - fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc) - fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef) - fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg) - fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv) + fmt.Fprintf(w, "#define El0ErrNMI 0x%02x\n", El0ErrNMI) fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 7f18ac296..7605d0cb2 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -30,6 +30,10 @@ type PageTables struct { Allocator Allocator // root is the pagetable root. + // + // For same archs such as amd64, the upper of the PTEs is cloned + // from and owned by upperSharedPageTables which are shared among + // many PageTables if upperSharedPageTables is not nil. root *PTEs // rootPhysical is the cached physical address of the root. @@ -39,15 +43,64 @@ type PageTables struct { // archPageTables includes architecture-specific features. archPageTables + + // upperSharedPageTables represents a read-only shared upper + // of the Pagetable. When it is not nil, the upper is not + // allowed to be modified. + upperSharedPageTables *PageTables + + // upperStart is the start address of the upper portion that + // are shared from upperSharedPageTables + upperStart uintptr + + // readOnlyShared indicates the Pagetables are read-only and + // own the ranges that are shared with other Pagetables. + readOnlyShared bool } -// New returns new PageTables. -func New(a Allocator) *PageTables { +// Init initializes a set of PageTables. +// +//go:nosplit +func (p *PageTables) Init(allocator Allocator) { + p.Allocator = allocator + p.root = p.Allocator.NewPTEs() + p.rootPhysical = p.Allocator.PhysicalFor(p.root) +} + +// NewWithUpper returns new PageTables. +// +// upperSharedPageTables are used for mapping the upper of addresses, +// starting at upperStart. These pageTables should not be touched (as +// invalidations may be incorrect) after they are passed as an +// upperSharedPageTables. Only when all dependent PageTables are gone +// may they be used. The intenteded use case is for kernel page tables, +// which are static and fixed. +// +// Precondition: upperStart must be between canonical ranges. +// Precondition: upperStart must be pgdSize aligned. +// precondition: upperSharedPageTables must be marked read-only shared. +func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables { p := new(PageTables) p.Init(a) + + if upperSharedPageTables != nil { + if !upperSharedPageTables.readOnlyShared { + panic("Only read-only shared pagetables can be used as upper") + } + p.upperSharedPageTables = upperSharedPageTables + p.upperStart = upperStart + } + + p.InitArch(a) + return p } +// New returns new PageTables. +func New(a Allocator) *PageTables { + return NewWithUpper(a, nil, 0) +} + // mapVisitor is used for map. type mapVisitor struct { target uintptr // Input. @@ -90,6 +143,21 @@ func (*mapVisitor) requiresSplit() bool { return true } // //go:nosplit func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { + if p.readOnlyShared { + panic("Should not modify read-only shared pagetables.") + } + if uintptr(addr)+length < uintptr(addr) { + panic("addr & length overflow") + } + if p.upperSharedPageTables != nil { + // ignore change to the read-only upper shared portion. + if uintptr(addr) >= p.upperStart { + return false + } + if uintptr(addr)+length > p.upperStart { + length = p.upperStart - uintptr(addr) + } + } if !opts.AccessType.Any() { return p.Unmap(addr, length) } @@ -128,12 +196,27 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // // True is returned iff there was a previous mapping in the range. // -// Precondition: addr & length must be page-aligned. +// Precondition: addr & length must be page-aligned, their sum must not overflow. // // +checkescape:hard,stack // //go:nosplit func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { + if p.readOnlyShared { + panic("Should not modify read-only shared pagetables.") + } + if uintptr(addr)+length < uintptr(addr) { + panic("addr & length overflow") + } + if p.upperSharedPageTables != nil { + // ignore change to the read-only upper shared portion. + if uintptr(addr) >= p.upperStart { + return false + } + if uintptr(addr)+length > p.upperStart { + length = p.upperStart - uintptr(addr) + } + } w := unmapWalker{ pageTables: p, visitor: unmapVisitor{ @@ -218,3 +301,10 @@ func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) w.iterateRange(uintptr(addr), uintptr(addr)+1) return w.visitor.physical + offset, w.visitor.opts } + +// MarkReadOnlyShared marks the pagetables read-only and can be shared. +// +// It is usually used on the pagetables that are used as the upper +func (p *PageTables) MarkReadOnlyShared() { + p.readOnlyShared = true +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go index 0c153cf8c..4bdde8448 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -41,13 +41,34 @@ const ( entriesPerPage = 512 ) -// Init initializes a set of PageTables. +// InitArch does some additional initialization related to the architecture. // //go:nosplit -func (p *PageTables) Init(allocator Allocator) { - p.Allocator = allocator - p.root = p.Allocator.NewPTEs() - p.rootPhysical = p.Allocator.PhysicalFor(p.root) +func (p *PageTables) InitArch(allocator Allocator) { + if p.upperSharedPageTables != nil { + p.cloneUpperShared() + } +} + +func pgdIndex(upperStart uintptr) uintptr { + if upperStart&(pgdSize-1) != 0 { + panic("upperStart should be pgd size aligned") + } + if upperStart >= upperBottom { + return entriesPerPage/2 + (upperStart-upperBottom)/pgdSize + } + if upperStart < lowerTop { + return upperStart / pgdSize + } + panic("upperStart should be in canonical range") +} + +// cloneUpperShared clone the upper from the upper shared page tables. +// +//go:nosplit +func (p *PageTables) cloneUpperShared() { + start := pgdIndex(p.upperStart) + copy(p.root[start:entriesPerPage], p.upperSharedPageTables.root[start:entriesPerPage]) } // PTEs is a collection of entries. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go index 1a49f12a2..ad0e30c88 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -36,21 +36,34 @@ const ( pudSize = 1 << pudShift pgdSize = 1 << pgdShift - ttbrASIDOffset = 55 + ttbrASIDOffset = 48 ttbrASIDMask = 0xff entriesPerPage = 512 ) -// Init initializes a set of PageTables. +// InitArch does some additional initialization related to the architecture. // //go:nosplit -func (p *PageTables) Init(allocator Allocator) { - p.Allocator = allocator - p.root = p.Allocator.NewPTEs() - p.rootPhysical = p.Allocator.PhysicalFor(p.root) - p.archPageTables.root = p.Allocator.NewPTEs() - p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) +func (p *PageTables) InitArch(allocator Allocator) { + if p.upperSharedPageTables != nil { + p.cloneUpperShared() + } else { + p.archPageTables.root = p.Allocator.NewPTEs() + p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) + } +} + +// cloneUpperShared clone the upper from the upper shared page tables. +// +//go:nosplit +func (p *PageTables) cloneUpperShared() { + if p.upperStart != upperBottom { + panic("upperStart should be the same as upperBottom") + } + + p.archPageTables.root = p.upperSharedPageTables.archPageTables.root + p.archPageTables.rootPhysical = p.upperSharedPageTables.archPageTables.rootPhysical } // PTEs is a collection of entries. diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index a3f775d15..cc1f6bfcc 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index ca16d0381..fb7c5dc61 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -23,7 +23,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserror", - "//pkg/tcpip", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..b88cdca48 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" ) @@ -344,21 +343,34 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { } // PackIPPacketInfo packs an IP_PKTINFO socket control message. -func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { - var p linux.ControlMessageIPPacketInfo - p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) - +func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, linux.IP_PKTINFO, t.Arch().Width(), - p, + packetInfo, ) } +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { + var level uint32 + var optType uint32 + switch originalDstAddress.(type) { + case *linux.SockAddrInet: + level = linux.SOL_IP + optType = linux.IP_RECVORIGDSTADDR + case *linux.SockAddrInet6: + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + default: + panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg") + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), originalDstAddress) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. @@ -384,7 +396,11 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt } if cmsgs.IP.HasIPPacketInfo { - buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) + } + + if cmsgs.IP.OriginalDstAddress != nil { + buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) } return buf @@ -416,17 +432,15 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageTClass) } - return space -} + if cmsgs.IP.HasIPPacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) + } -// NewIPPacketInfo returns the IPPacketInfo struct. -func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { - var p tcpip.IPPacketInfo - p.NIC = tcpip.NICID(packetInfo.NIC) - copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) - copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + if cmsgs.IP.OriginalDstAddress != nil { + space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) + } - return p + return space } // Parse parses a raw socket control message into portable objects. @@ -489,6 +503,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.Unix.Credentials = scmCreds i += binary.AlignUp(length, width) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + return socket.ControlMessages{}, syserror.EINVAL + } + cmsgs.IP.HasTimestamp = true + binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp) + i += binary.AlignUp(length, width) + default: // Unknown message type. return socket.ControlMessages{}, syserror.EINVAL @@ -512,7 +534,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + cmsgs.IP.PacketInfo = packetInfo + i += binary.AlignUp(length, width) + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) default: @@ -528,6 +559,15 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) i += binary.AlignUp(length, width) + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index d9621968c..37d02948f 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -24,6 +24,8 @@ import ( ) // SCMRightsVFS2 represents a SCM_RIGHTS socket control message. +// +// +stateify savable type SCMRightsVFS2 interface { transport.RightsControlMessage @@ -34,9 +36,11 @@ type SCMRightsVFS2 interface { Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool) } -// RightsFiles represents a SCM_RIGHTS socket control message. A reference is -// maintained for each vfs.FileDescription and is release either when an FD is created or -// when the Release method is called. +// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference +// is maintained for each vfs.FileDescription and is release either when an FD +// is created or when the Release method is called. +// +// +stateify savable type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..be418df2e 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 case linux.SO_LINGER: optlen = syscall.SizeofLinger @@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR: optlen = sizeofInt32 case linux.IP_PKTINFO: optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 } case linux.SOL_TCP: switch name { - case linux.TCP_NODELAY: + case linux.TCP_NODELAY, linux.TCP_INQ: optlen = sizeofInt32 } } @@ -513,24 +513,48 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags controlMessages := socket.ControlMessages{} for _, unixCmsg := range unixControlMessages { switch unixCmsg.Header.Level { - case syscall.SOL_IP: + case linux.SOL_SOCKET: switch unixCmsg.Header.Type { - case syscall.IP_TOS: + case linux.SO_TIMESTAMP: + controlMessages.IP.HasTimestamp = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp) + } + + case linux.SOL_IP: + switch unixCmsg.Header.Type { + case linux.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) - case syscall.IP_PKTINFO: + case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) + controlMessages.IP.PacketInfo = packetInfo + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr } - case syscall.SOL_IPV6: + case linux.SOL_IPV6: switch unixCmsg.Header.Type { - case syscall.IPV6_TCLASS: + case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + } + + case linux.SOL_TCP: + switch unixCmsg.Header.Type { + case linux.TCP_INQ: + controlMessages.IP.HasInq = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq) } } } diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 163af329b..9a2cac40b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// +stateify savable type socketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{}) func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &socketVFS2{ diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index faa61160e..7e7857ac3 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { + return syserror.EACCES +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return syserror.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(enabled bool) error { +func (s *Stack) SetTCPSACKEnabled(bool) error { return syserror.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { +func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return syserror.EACCES } @@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { } if rawLine == "" { - return fmt.Errorf("Failed to get raw line") + return fmt.Errorf("failed to get raw line") } parts := strings.SplitN(rawLine, ":", 2) if len(parts) != 2 { - return fmt.Errorf("Failed to get prefix from: %q", rawLine) + return fmt.Errorf("failed to get prefix from: %q", rawLine) } sliceStat = toSlice(stat) fields := strings.Fields(strings.TrimSpace(parts[1])) if len(fields) != len(sliceStat) { - return fmt.Errorf("Failed to parse fields: %q", rawLine) + return fmt.Errorf("failed to parse fields: %q", rawLine) } if _, ok := stat.(*inet.StatSNMPTCP); ok { snmpTCP = true @@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) } if err != nil { - return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err) } } @@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { } // SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { +func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 549787955..e0976fed0 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -100,24 +100,43 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf // marshalTarget and unmarshalTarget can be used. type targetMaker interface { // id uniquely identifies the target. - id() stack.TargetID + id() targetID - // marshal converts from a stack.Target to an ABI struct. - marshal(target stack.Target) []byte + // marshal converts from a target to an ABI struct. + marshal(target target) []byte - // unmarshal converts from the ABI matcher struct to a stack.Target. - unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) + // unmarshal converts from the ABI matcher struct to a target. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) } -// targetMakers maps the TargetID of supported targets to the targetMaker that +// A targetID uniquely identifies a target. +type targetID struct { + // name is the target name as stored in the xt_entry_target struct. + name string + + // networkProtocol is the protocol to which the target applies. + networkProtocol tcpip.NetworkProtocolNumber + + // revision is the version of the target. + revision uint8 +} + +// target extends a stack.Target, allowing it to be used with the extension +// system. The sentry only uses targets, never stack.Targets directly. +type target interface { + stack.Target + id() targetID +} + +// targetMakers maps the targetID of supported targets to the targetMaker that // marshals and unmarshals it. It is immutable after package initialization. -var targetMakers = map[stack.TargetID]targetMaker{} +var targetMakers = map[targetID]targetMaker{} func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) { - tid := stack.TargetID{ - Name: name, - NetworkProtocol: netProto, - Revision: rev, + tid := targetID{ + name: name, + networkProtocol: netProto, + revision: rev, } if _, ok := targetMakers[tid]; !ok { return 0, false @@ -126,8 +145,8 @@ func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8 // Return the highest supported revision unless rev is higher. for _, other := range targetMakers { otherID := other.id() - if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev { - rev = uint8(otherID.Revision) + if name == otherID.name && netProto == otherID.networkProtocol && otherID.revision > rev { + rev = uint8(otherID.revision) } } return rev, true @@ -142,19 +161,21 @@ func registerTargetMaker(tm targetMaker) { targetMakers[tm.id()] = tm } -func marshalTarget(target stack.Target) []byte { - targetMaker, ok := targetMakers[target.ID()] +func marshalTarget(tgt stack.Target) []byte { + // The sentry only uses targets, never stack.Targets directly. + target := tgt.(target) + targetMaker, ok := targetMakers[target.id()] if !ok { - panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID())) + panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.id())) } return targetMaker.marshal(target) } -func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) { - tid := stack.TargetID{ - Name: target.Name.String(), - NetworkProtocol: filter.NetworkProtocol(), - Revision: target.Revision, +func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (target, *syserr.Error) { + tid := targetID{ + name: target.Name.String(), + networkProtocol: filter.NetworkProtocol(), + revision: target.Revision, } targetMaker, ok := targetMakers[tid] if !ok { diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index b560fae0d..70c561cce 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -46,13 +46,13 @@ func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linu return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - table, ok := stk.IPTables().GetTable(tablename.String(), false) + id, ok := nameToID[tablename.String()] if !ok { return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) } // Setup the info struct. - entries, info := getEntries4(table, tablename) + entries, info := getEntries4(stk.IPTables().GetTable(id, false), tablename) return entries, info, nil } diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 4253f7bf4..5dbb604f0 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -46,13 +46,13 @@ func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linu return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - table, ok := stk.IPTables().GetTable(tablename.String(), true) + id, ok := nameToID[tablename.String()] if !ok { return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) } // Setup the info struct, which is the same in IPv4 and IPv6. - entries, info := getEntries6(table, tablename) + entries, info := getEntries6(stk.IPTables().GetTable(id, true), tablename) return entries, info, nil } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 904a12e38..b283d7229 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -42,6 +42,45 @@ func nflog(format string, args ...interface{}) { } } +// Table names. +const ( + natTable = "nat" + mangleTable = "mangle" + filterTable = "filter" +) + +// nameToID is immutable. +var nameToID = map[string]stack.TableID{ + natTable: stack.NATID, + mangleTable: stack.MangleID, + filterTable: stack.FilterID, +} + +// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for +// compatibility with netfilter extensions. +func DefaultLinuxTables() *stack.IPTables { + tables := stack.DefaultTables() + tables.VisitTargets(func(oldTarget stack.Target) stack.Target { + switch val := oldTarget.(type) { + case *stack.AcceptTarget: + return &acceptTarget{AcceptTarget: *val} + case *stack.DropTarget: + return &dropTarget{DropTarget: *val} + case *stack.ErrorTarget: + return &errorTarget{ErrorTarget: *val} + case *stack.UserChainTarget: + return &userChainTarget{UserChainTarget: *val} + case *stack.ReturnTarget: + return &returnTarget{ReturnTarget: *val} + case *stack.RedirectTarget: + return &redirectTarget{RedirectTarget: *val} + default: + panic(fmt.Sprintf("Unknown rule in default iptables of type %T", val)) + } + }) + return tables +} + // GetInfo returns information about iptables. func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. @@ -144,9 +183,9 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.FilterTable: + case filterTable: table = stack.EmptyFilterTable() - case stack.NATTable: + case natTable: table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) @@ -177,7 +216,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } if offset == replace.Underflow[hook] { if !validUnderflow(table.Rules[ruleIdx], ipv6) { - nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx) + nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP: %+v", ruleIdx) return syserr.ErrInvalidArgument } table.Underflows[hk] = ruleIdx @@ -253,8 +292,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6)) - + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(nameToID[replace.Name.String()], table, ipv6)) } // parseMatchers parses 0 or more matchers from optVal. optVal should contain @@ -308,7 +346,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool { return false } switch rule.Target.(type) { - case *stack.AcceptTarget, *stack.DropTarget: + case *acceptTarget, *dropTarget: return true default: return false @@ -319,7 +357,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { if !validUnderflow(rule, ipv6) { return false } - _, ok := rule.Target.(*stack.AcceptTarget) + _, ok := rule.Target.(*acceptTarget) return ok } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 0e14447fe..f2653d523 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -26,6 +26,15 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// ErrorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const ErrorTargetName = "ERROR" + +// RedirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port and/or IP for packets. +const RedirectTargetName = "REDIRECT" + func init() { // Standard targets include ACCEPT, DROP, RETURN, and JUMP. registerTargetMaker(&standardTargetMaker{ @@ -52,25 +61,96 @@ func init() { }) } +// The stack package provides some basic, useful targets for us. The following +// types wrap them for compatibility with the extension system. + +type acceptTarget struct { + stack.AcceptTarget +} + +func (at *acceptTarget) id() targetID { + return targetID{ + networkProtocol: at.NetworkProtocol, + } +} + +type dropTarget struct { + stack.DropTarget +} + +func (dt *dropTarget) id() targetID { + return targetID{ + networkProtocol: dt.NetworkProtocol, + } +} + +type errorTarget struct { + stack.ErrorTarget +} + +func (et *errorTarget) id() targetID { + return targetID{ + name: ErrorTargetName, + networkProtocol: et.NetworkProtocol, + } +} + +type userChainTarget struct { + stack.UserChainTarget +} + +func (uc *userChainTarget) id() targetID { + return targetID{ + name: ErrorTargetName, + networkProtocol: uc.NetworkProtocol, + } +} + +type returnTarget struct { + stack.ReturnTarget +} + +func (rt *returnTarget) id() targetID { + return targetID{ + networkProtocol: rt.NetworkProtocol, + } +} + +type redirectTarget struct { + stack.RedirectTarget + + // addr must be (un)marshalled when reading and writing the target to + // userspace, but does not affect behavior. + addr tcpip.Address +} + +func (rt *redirectTarget) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rt.NetworkProtocol, + } +} + type standardTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (sm *standardTargetMaker) id() stack.TargetID { +func (sm *standardTargetMaker) id() targetID { // Standard targets have the empty string as a name and no revisions. - return stack.TargetID{ - NetworkProtocol: sm.NetworkProtocol, + return targetID{ + networkProtocol: sm.NetworkProtocol, } } -func (*standardTargetMaker) marshal(target stack.Target) []byte { + +func (*standardTargetMaker) marshal(target target) []byte { // Translate verdicts the same way as the iptables tool. var verdict int32 switch tg := target.(type) { - case *stack.AcceptTarget: + case *acceptTarget: verdict = -linux.NF_ACCEPT - 1 - case *stack.DropTarget: + case *dropTarget: verdict = -linux.NF_DROP - 1 - case *stack.ReturnTarget: + case *returnTarget: verdict = linux.NF_RETURN case *JumpTarget: verdict = int32(tg.Offset) @@ -90,7 +170,7 @@ func (*standardTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) != linux.SizeOfXTStandardTarget { nflog("buf has wrong size for standard target %d", len(buf)) return nil, syserr.ErrInvalidArgument @@ -114,20 +194,20 @@ type errorTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (em *errorTargetMaker) id() stack.TargetID { +func (em *errorTargetMaker) id() targetID { // Error targets have no revision. - return stack.TargetID{ - Name: stack.ErrorTargetName, - NetworkProtocol: em.NetworkProtocol, + return targetID{ + name: ErrorTargetName, + networkProtocol: em.NetworkProtocol, } } -func (*errorTargetMaker) marshal(target stack.Target) []byte { +func (*errorTargetMaker) marshal(target target) []byte { var errorName string switch tg := target.(type) { - case *stack.ErrorTarget: - errorName = stack.ErrorTargetName - case *stack.UserChainTarget: + case *errorTarget: + errorName = ErrorTargetName + case *userChainTarget: errorName = tg.Name default: panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target)) @@ -140,37 +220,38 @@ func (*errorTargetMaker) marshal(target stack.Target) []byte { }, } copy(xt.Name[:], errorName) - copy(xt.Target.Name[:], stack.ErrorTargetName) + copy(xt.Target.Name[:], ErrorTargetName) ret := make([]byte, 0, linux.SizeOfXTErrorTarget) return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) != linux.SizeOfXTErrorTarget { nflog("buf has insufficient size for error target %d", len(buf)) return nil, syserr.ErrInvalidArgument } - var errorTarget linux.XTErrorTarget + var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + binary.Unmarshal(buf, usermem.ByteOrder, &errTgt) // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named stack.ErrorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. + // * An actual error case. These rules have an error named + // ErrorTargetName. The last entry of the table is usually an error + // case to catch any packets that somehow fall through every rule. // * To mark the start of a user defined chain. These // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case stack.ErrorTargetName: - return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil + switch name := errTgt.Name.String(); name { + case ErrorTargetName: + return &errorTarget{stack.ErrorTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }}, nil default: // User defined chain. - return &stack.UserChainTarget{ + return &userChainTarget{stack.UserChainTarget{ Name: name, NetworkProtocol: filter.NetworkProtocol(), - }, nil + }}, nil } } @@ -178,22 +259,22 @@ type redirectTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (rm *redirectTargetMaker) id() stack.TargetID { - return stack.TargetID{ - Name: stack.RedirectTargetName, - NetworkProtocol: rm.NetworkProtocol, +func (rm *redirectTargetMaker) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rm.NetworkProtocol, } } -func (*redirectTargetMaker) marshal(target stack.Target) []byte { - rt := target.(*stack.RedirectTarget) +func (*redirectTargetMaker) marshal(target target) []byte { + rt := target.(*redirectTarget) // This is a redirect target named redirect xt := linux.XTRedirectTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTRedirectTarget, }, } - copy(xt.Target.Name[:], stack.RedirectTargetName) + copy(xt.Target.Name[:], RedirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) xt.NfRange.RangeSize = 1 @@ -203,7 +284,7 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) < linux.SizeOfXTRedirectTarget { nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf)) return nil, syserr.ErrInvalidArgument @@ -214,15 +295,17 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - var redirectTarget linux.XTRedirectTarget + var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + binary.Unmarshal(buf, usermem.ByteOrder, &rt) // Copy linux.XTRedirectTarget to stack.RedirectTarget. - target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()} + target := redirectTarget{RedirectTarget: stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }} // RangeSize should be 1. - nfRange := redirectTarget.NfRange + nfRange := rt.NfRange if nfRange.RangeSize != 1 { nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize) return nil, syserr.ErrInvalidArgument @@ -247,7 +330,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) target.Port = ntohs(nfRange.RangeIPV4.MinPort) return &target, nil @@ -264,15 +347,15 @@ type nfNATTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (rm *nfNATTargetMaker) id() stack.TargetID { - return stack.TargetID{ - Name: stack.RedirectTargetName, - NetworkProtocol: rm.NetworkProtocol, +func (rm *nfNATTargetMaker) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rm.NetworkProtocol, } } -func (*nfNATTargetMaker) marshal(target stack.Target) []byte { - rt := target.(*stack.RedirectTarget) +func (*nfNATTargetMaker) marshal(target target) []byte { + rt := target.(*redirectTarget) nt := nfNATTarget{ Target: linux.XTEntryTarget{ TargetSize: nfNATMarhsalledSize, @@ -281,9 +364,9 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte { Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, }, } - copy(nt.Target.Name[:], stack.RedirectTargetName) - copy(nt.Range.MinAddr[:], rt.Addr) - copy(nt.Range.MaxAddr[:], rt.Addr) + copy(nt.Target.Name[:], RedirectTargetName) + copy(nt.Range.MinAddr[:], rt.addr) + copy(nt.Range.MaxAddr[:], rt.addr) nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto @@ -292,7 +375,7 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, nt) } -func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if size := nfNATMarhsalledSize; len(buf) < size { nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) return nil, syserr.ErrInvalidArgument @@ -324,10 +407,12 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta return nil, syserr.ErrInvalidArgument } - target := stack.RedirectTarget{ - NetworkProtocol: filter.NetworkProtocol(), - Addr: tcpip.Address(natRange.MinAddr[:]), - Port: ntohs(natRange.MinProto), + target := redirectTarget{ + RedirectTarget: stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Port: ntohs(natRange.MinProto), + }, + addr: tcpip.Address(natRange.MinAddr[:]), } return &target, nil @@ -335,18 +420,24 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) { +func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: - return &stack.AcceptTarget{NetworkProtocol: netProto}, nil + return &acceptTarget{stack.AcceptTarget{ + NetworkProtocol: netProto, + }}, nil case -linux.NF_DROP - 1: - return &stack.DropTarget{NetworkProtocol: netProto}, nil + return &dropTarget{stack.DropTarget{ + NetworkProtocol: netProto, + }}, nil case -linux.NF_QUEUE - 1: nflog("unsupported iptables verdict QUEUE") return nil, syserr.ErrInvalidArgument case linux.NF_RETURN: - return &stack.ReturnTarget{NetworkProtocol: netProto}, nil + return &returnTarget{stack.ReturnTarget{ + NetworkProtocol: netProto, + }}, nil default: nflog("unknown iptables verdict %d", val) return nil, syserr.ErrInvalidArgument @@ -382,9 +473,9 @@ type JumpTarget struct { } // ID implements Target.ID. -func (jt *JumpTarget) ID() stack.TargetID { - return stack.TargetID{ - NetworkProtocol: jt.NetworkProtocol, +func (jt *JumpTarget) id() targetID { + return targetID{ + networkProtocol: jt.NetworkProtocol, } } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 844acfede..352c51390 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.TCPProtocolNumber { - return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) + return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber) } return &TCPMatcher{ diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 63201201c..c88d8268d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber) } return &UDPMatcher{ diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go index e8930f031..f061c5d62 100644 --- a/pkg/sentry/socket/netlink/provider_vfs2.go +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol vfsfd := &s.vfsfd mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index c84d8bd7c..f4d034c13 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -36,9 +36,9 @@ type commandKind int const ( kindNew commandKind = 0x0 - kindDel = 0x1 - kindGet = 0x2 - kindSet = 0x3 + kindDel commandKind = 0x1 + kindGet commandKind = 0x2 + kindSet commandKind = 0x3 ) func typeKind(typ uint16) commandKind { @@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } attrs = rest + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We add the local interface address here + // and ignore the IFA_ADDRESS. switch ahdr.Type { case linux.IFA_LOCAL: err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ @@ -439,11 +444,60 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } else if err != nil { return syserr.ErrInvalidArgument } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported } } return nil } +// delAddr handles RTM_DELADDR requests. +func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We use the local interface address to + // remove the address and ignore the IFA_ADDRESS. + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err != nil { + return syserr.ErrBadLocalAddress + } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported + } + } + + return nil +} + // ProcessMessage implements netlink.Protocol.ProcessMessage. func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { hdr := msg.Header() @@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: return p.newAddr(ctx, msg, ms) + case linux.RTM_DELADDR: + return p.delAddr(ctx, msg, ms) default: return syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 3baad098b..057f4d294 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -120,9 +120,6 @@ type socketOpsCommon struct { // fixed buffer but only consume this many bytes. sendBufferSize uint32 - // passcred indicates if this socket wants SCM credentials. - passcred bool - // filter indicates that this socket has a BPF filter "installed". // // TODO(gvisor.dev/issue/1119): We don't actually support filtering, @@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { // Passcred implements transport.Credentialer.Passcred. func (s *socketOpsCommon) Passcred() bool { - s.mu.Lock() - passcred := s.passcred - s.mu.Unlock() - return passcred + return s.ep.SocketOptions().GetPassCred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. @@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] } passcred := usermem.ByteOrder.Uint32(opt) - s.mu.Lock() - s.passcred = passcred != 0 - s.mu.Unlock() + s.ep.SocketOptions().SetPassCred(passcred != 0) return nil case linux.SO_ATTACH_FILTER: diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index c83b23242..461d524e5 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -37,6 +37,8 @@ import ( // to/from the kernel. // // SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 211f07947..23d5cab9c 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -84,69 +84,95 @@ var Metrics = tcpip.Stats{ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), ICMP: tcpip.ICMPStats{ - V4PacketsSent: tcpip.ICMPv4SentPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + V4: tcpip.ICMPv4Stats{ + PacketsSent: tcpip.ICMPv4SentPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), + }, + PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), }, - V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + V6: tcpip.ICMPv6Stats{ + PacketsSent: tcpip.ICMPv6SentPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + }, + PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - V6PacketsSent: tcpip.ICMPv6SentPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + IGMP: tcpip.IGMPStats{ + PacketsSent: tcpip.IGMPSentPacketStats{ + IGMPPacketStats: tcpip.IGMPPacketStats{ + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."), }, - V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + PacketsReceived: tcpip.IGMPReceivedPacketStats{ + IGMPPacketStats: tcpip.IGMPPacketStats{ + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."), }, - Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."), + ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."), }, }, IP: tcpip.IPStats{ @@ -209,18 +235,6 @@ const sizeOfInt32 int = 4 var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) -// ntohs converts a 16-bit number from network byte order to host byte order. It -// assumes that the host is little endian. -func ntohs(v uint16) uint16 { - return v<<8 | v>>8 -} - -// htons converts a 16-bit number from host byte order to network byte order. It -// assumes that the host is little endian. -func htons(v uint16) uint16 { - return ntohs(v) -} - // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. type commonEndpoint interface { @@ -240,10 +254,6 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and - // transport.Endpoint.SetSockOptBool. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -252,16 +262,21 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and - // transport.Endpoint.GetSockOpt. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) - // LastError implements tcpip.Endpoint.LastError. + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 + + // LastError implements tcpip.Endpoint.LastError and + // transport.Endpoint.LastError. LastError() *tcpip.Error + + // SocketOptions implements tcpip.Endpoint.SocketOptions and + // transport.Endpoint.SocketOptions. + SocketOptions() *tcpip.SocketOptions } // LINT.IfChange @@ -305,7 +320,7 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages + readCM socket.IPControlMessages sender tcpip.FullAddress linkPacketInfo tcpip.LinkPacketInfo @@ -329,9 +344,7 @@ type socketOpsCommon struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } dirent := socket.NewDirent(t, netstackDevice) @@ -360,88 +373,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, -// AF_INET6, and AF_PACKET addresses. -// -// AddressAndFamily returns an address and its family. -func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { - // Make sure we have at least 2 bytes for the address family. - if len(addr) < 2 { - return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument - } - - // Get the rest of the fields based on the address family. - switch family := usermem.ByteOrder.Uint16(addr); family { - case linux.AF_UNIX: - path := addr[2:] - if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - // Drop the terminating NUL (if one exists) and everything after - // it for filesystem (non-abstract) addresses. - if len(path) > 0 && path[0] != 0 { - if n := bytes.IndexByte(path[1:], 0); n >= 0 { - path = path[:n+1] - } - } - return tcpip.FullAddress{ - Addr: tcpip.Address(path), - }, family, nil - - case linux.AF_INET: - var a linux.SockAddrInet - if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - return out, family, nil - - case linux.AF_INET6: - var a linux.SockAddrInet6 - if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - if isLinkLocal(out.Addr) { - out.NIC = tcpip.NICID(a.Scope_id) - } - return out, family, nil - - case linux.AF_PACKET: - var a linux.SockAddrLink - if len(addr) < sockAddrLinkSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) - if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/173): Return protocol too. - return tcpip.FullAddress{ - NIC: tcpip.NICID(a.InterfaceIndex), - Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), - }, family, nil - - case linux.AF_UNSPEC: - return tcpip.FullAddress{}, family, nil - - default: - return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported - } -} - func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } @@ -477,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = v - s.readCM = cms + s.readCM = socket.NewIPControlMessages(s.family, cms) atomic.StoreUint32(&s.readViewHasData, 1) return nil @@ -497,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { return } - var v tcpip.LingerOption - if err := s.Endpoint.GetSockOpt(&v); err != nil { - return - } - + v := s.Endpoint.SocketOptions().GetLinger() // The case for zero timeout is handled in tcp endpoint close function. // Close is blocked until either: // 1. The endpoint state is not in any of the states: FIN-WAIT1, @@ -718,11 +645,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { return nil } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { - v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return syserr.TranslateNetstackError(err) - } - if !v { + if !s.Endpoint.SocketOptions().GetV6Only() { return nil } } @@ -746,7 +669,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -827,7 +750,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } } else { var err *syserr.Error - addr, family, err = AddressAndFamily(sockaddr) + addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -918,7 +841,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -1002,7 +925,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptSocket(t, s, ep, family, skType, name, outLen) case linux.SOL_TCP: - return getSockOptTCP(t, ep, name, outLen) + return getSockOptTCP(t, s, ep, name, outLen) case linux.SOL_IPV6: return getSockOptIPv6(t, s, ep, name, outPtr, outLen) @@ -1038,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. - err := ep.LastError() + err := ep.SocketOptions().GetLastError() if err == nil { optP := primitive.Int32(0) return &optP, nil @@ -1065,13 +988,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.PasscredOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred())) + return &v, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1112,25 +1030,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress())) + return &v, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReusePortOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort())) + return &v, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1163,37 +1072,24 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.BroadcastOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetBroadcast())) + return &v, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetKeepAlive())) + return &v, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - var v tcpip.LingerOption var linger linux.Linger - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetLinger() if v.Enabled { linger.OnOff = 1 @@ -1224,22 +1120,26 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.OutOfBandInlineOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline())) + return &v, nil + + case linux.SO_NO_CHECK: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(v) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum())) + return &v, nil - case linux.SO_NO_CHECK: + case linux.SO_ACCEPTCONN: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + // This option is only viable for TCP endpoints. + var v bool + if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) { + v = tcp.EndpointState(ep.State()) == tcp.StateListen } vP := primitive.Int32(boolToInt32(v)) return &vP, nil @@ -1251,46 +1151,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(!v)) - return &vP, nil + v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption())) + return &v, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.CorkOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption())) + return &v, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.QuickAckOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck())) + return &v, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1464,19 +1354,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + + family, skType, _ := s.Type() + if family != linux.AF_INET6 { + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only())) + return &v, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1508,13 +1403,16 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) + return &v, nil + + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1526,7 +1424,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: @@ -1535,7 +1433,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1555,7 +1453,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1575,7 +1473,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1597,6 +1495,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name // getSockOptIP implements GetSockOpt when level is SOL_IP. func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1639,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) return &a.(*linux.SockAddrInet).Addr, nil @@ -1648,13 +1551,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop())) + return &v, nil case linux.IP_TOS: // Length handling for parity with Linux. @@ -1678,26 +1576,32 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) + return &v, nil + + case linux.IP_PKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo())) + return &v, nil - case linux.IP_PKTINFO: + case linux.IP_HDRINCL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) + return &v, nil + + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { @@ -1709,7 +1613,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet), nil case linux.IPT_SO_GET_INFO: @@ -1816,7 +1720,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptSocket(t, s, ep, name, optVal) case linux.SOL_TCP: - return setSockOptTCP(t, ep, name, optVal) + return setSockOptTCP(t, s, ep, name, optVal) case linux.SOL_IPV6: return setSockOptIPv6(t, s, ep, name, optVal) @@ -1866,7 +1770,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) + ep.SocketOptions().SetReuseAddress(v != 0) + return nil case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1874,7 +1779,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) + ep.SocketOptions().SetReusePort(v != 0) + return nil case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) @@ -1904,7 +1810,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0)) + ep.SocketOptions().SetBroadcast(v != 0) + return nil case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { @@ -1912,7 +1819,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) + ep.SocketOptions().SetPassCred(v != 0) + return nil case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1920,7 +1828,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) + ep.SocketOptions().SetKeepAlive(v != 0) + return nil case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1959,8 +1868,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - opt := tcpip.OutOfBandInlineOption(v) - return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) + ep.SocketOptions().SetOutOfBandInline(v != 0) + return nil case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1968,7 +1877,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + ep.SocketOptions().SetNoChecksum(v != 0) + return nil case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { @@ -1982,10 +1892,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError( - ep.SetSockOpt(&tcpip.LingerOption{ - Enabled: v.OnOff != 0, - Timeout: time.Second * time.Duration(v.Linger)})) + ep.SocketOptions().SetLinger(tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger), + }) + return nil case linux.SO_DETACH_FILTER: // optval is ignored. @@ -2000,7 +1911,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. -func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if len(optVal) < sizeOfInt32 { @@ -2008,7 +1924,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) + ep.SocketOptions().SetDelayOption(v == 0) + return nil case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -2016,7 +1933,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) + ep.SocketOptions().SetCorkOption(v != 0) + return nil case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -2024,7 +1942,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) + ep.SocketOptions().SetQuickAck(v != 0) + return nil case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -2136,14 +2055,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + + family, skType, skProto := s.Type() + if family != linux.AF_INET6 { + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } + if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { + return syserr.ErrInvalidEndpointState + } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + return syserr.ErrInvalidEndpointState + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) + ep.SocketOptions().SetV6Only(v != 0) + return nil case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, @@ -2163,6 +2099,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2182,7 +2127,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + ep.SocketOptions().SetReceiveTClass(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2190,7 +2136,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return syserr.ErrProtocolNotAvailable } @@ -2265,6 +2211,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2317,7 +2268,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), - InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), + InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]), })) case linux.IP_MULTICAST_LOOP: @@ -2326,7 +2277,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) + ep.SocketOptions().SetMulticastLoop(v != 0) + return nil case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2362,7 +2314,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + ep.SocketOptions().SetReceiveTOS(v != 0) + return nil case linux.IP_PKTINFO: if len(optVal) == 0 { @@ -2372,7 +2325,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + ep.SocketOptions().SetReceivePacketInfo(v != 0) + return nil case linux.IP_HDRINCL: if len(optVal) == 0 { @@ -2382,7 +2336,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + ep.SocketOptions().SetHeaderIncluded(v != 0) + return nil + + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { @@ -2422,7 +2389,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2500,7 +2466,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2524,7 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { switch name { case linux.IP_TOS, linux.IP_TTL, - linux.IP_HDRINCL, linux.IP_OPTIONS, linux.IP_ROUTER_ALERT, linux.IP_RECVOPTS, @@ -2571,72 +2535,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { } } -// isLinkLocal determines if the given IPv6 address is link-local. This is the -// case when it has the fe80::/10 prefix. This check is used to determine when -// the NICID is relevant for a given IPv6 address. -func isLinkLocal(addr tcpip.Address) bool { - return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 -} - -// ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { - switch family { - case linux.AF_UNIX: - var out linux.SockAddrUnix - out.Family = linux.AF_UNIX - l := len([]byte(addr.Addr)) - for i := 0; i < l; i++ { - out.Path[i] = int8(addr.Addr[i]) - } - - // Linux returns the used length of the address struct (including the - // null terminator) for filesystem paths. The Family field is 2 bytes. - // It is sometimes allowed to exclude the null terminator if the - // address length is the max. Abstract and empty paths always return - // the full exact length. - if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return &out, uint32(2 + l) - } - return &out, uint32(3 + l) - - case linux.AF_INET: - var out linux.SockAddrInet - copy(out.Addr[:], addr.Addr) - out.Family = linux.AF_INET - out.Port = htons(addr.Port) - return &out, uint32(sockAddrInetSize) - - case linux.AF_INET6: - var out linux.SockAddrInet6 - if len(addr.Addr) == header.IPv4AddressSize { - // Copy address in v4-mapped format. - copy(out.Addr[12:], addr.Addr) - out.Addr[10] = 0xff - out.Addr[11] = 0xff - } else { - copy(out.Addr[:], addr.Addr) - } - out.Family = linux.AF_INET6 - out.Port = htons(addr.Port) - if isLinkLocal(addr.Addr) { - out.Scope_id = uint32(addr.NIC) - } - return &out, uint32(sockAddrInet6Size) - - case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. - var out linux.SockAddrLink - out.Family = linux.AF_PACKET - out.InterfaceIndex = int32(addr.NIC) - out.HardwareAddrLen = header.EthernetAddressSize - copy(out.HardwareAddr[:], addr.Addr) - return &out, uint32(sockAddrLinkSize) - - default: - return nil, 0 - } -} - // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { @@ -2645,7 +2543,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2657,7 +2555,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2675,7 +2573,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ // Always do at least one fetchReadView, even if the number of bytes to // read is 0. err = s.fetchReadView() - if err != nil { + if err != nil || len(s.readView) == 0 { break } if dst.NumBytes() == 0 { @@ -2698,15 +2596,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ } copied += n s.readView.TrimFront(n) - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) - } dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) break } + // If we are done reading requested data then stop. + if dst.NumBytes() == 0 { + break + } + } + + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) } // If we managed to copy something, we must deliver it. @@ -2801,10 +2704,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { - addr, addrLen = ConvertAddress(s.family, s.sender) + addr, addrLen = socket.ConvertAddress(s.family, s.sender) switch v := addr.(type) { case *linux.SockAddrLink: - v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } @@ -2822,7 +2725,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // We need to peek beyond the first message. dst = dst.DropFirst(n) num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) { - n, _, err := s.Endpoint.Peek(dsts) + n, err := s.Endpoint.Peek(dsts) // TODO(b/78348848): Handle peek timestamp. if err != nil { return int64(n), syserr.TranslateNetstackError(err).ToError() @@ -2866,15 +2769,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ - IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + IP: socket.IPControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + OriginalDstAddress: s.readCM.OriginalDstAddress, }, } } @@ -2969,7 +2873,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, family, err := AddressAndFamily(to) + addrBuf, family, err := socket.AddressAndFamily(to) if err != nil { return 0, err } @@ -3388,6 +3292,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { return rv } +func isTCPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP) +} + +func isUDPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP) +} + +func isICMPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6) +} + // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. func (s *socketOpsCommon) State() uint32 { @@ -3397,7 +3313,7 @@ func (s *socketOpsCommon) State() uint32 { } switch { - case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: + case isTCPSocket(s.skType, s.protocol): // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -3426,7 +3342,7 @@ func (s *socketOpsCommon) State() uint32 { // Internal or unknown state. return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + case isUDPSocket(s.skType, s.protocol): // UDP socket. switch udp.EndpointState(s.Endpoint.State()) { case udp.StateInitial, udp.StateBound, udp.StateClosed: @@ -3436,7 +3352,7 @@ func (s *socketOpsCommon) State() uint32 { default: return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + case isICMPSocket(s.skType, s.protocol): // TODO(b/112063468): Export states for ICMP sockets. case s.skType == linux.SOCK_RAW: // TODO(b/112063468): Export states for raw sockets. diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 4c6791fff..b756bfca0 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -35,6 +35,8 @@ import ( // SocketVFS2 encapsulates all the state needed to represent a network stack // endpoint in the kernel context. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -49,13 +51,11 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // NewVFS2 creates a new endpoint socket. func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &SocketVFS2{ @@ -189,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addrLen uint32 if peerAddr != nil { // Get address of the peer and write it to peer slice. - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index ead3b2b79..c847ff1c7 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 2a01143f6..0af805246 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 1028d2a6e..cc0fadeb5 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } -// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +// convertAddr converts an InterfaceAddr to a ProtocolAddress. +func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) { var ( - protocol tcpip.NetworkProtocolNumber - address tcpip.Address + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + protocolAddress tcpip.ProtocolAddress ) switch addr.Family { case linux.AF_INET: - if len(addr.Addr) < header.IPv4AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv4AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv4AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv4.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) - + address = tcpip.Address(addr.Addr) case linux.AF_INET6: - if len(addr.Addr) < header.IPv6AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv6AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv6AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv6.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) - + address = tcpip.Address(addr.Addr) default: - return syserror.ENOTSUP + return protocolAddress, syserror.ENOTSUP } - protocolAddress := tcpip.ProtocolAddress{ + protocolAddress = tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: int(addr.PrefixLen), }, } + return protocolAddress, nil +} + +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } // Attach address to interface. - if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + nicID := tcpip.NICID(idx) + if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network if it doesn't exist already. + localRoute := tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: nicID, + } + + for _, rt := range s.Stack.GetRouteTable() { + if rt.Equal(localRoute) { + return nil + } + } + + // Local route does not exist yet. Add it. + s.Stack.AddRoute(localRoute) + + return nil +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } + + // Remove addresses matching the address and prefix. + nicID := tcpip.NICID(idx) + if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil { return syserr.TranslateNetstackError(err).ToError() } - // Add route for local network. - s.Stack.AddRoute(tcpip.Route{ + // Remove the corresponding local network route if it exists. + localRoute := tcpip.Route{ Destination: protocolAddress.AddressWithPrefix.Subnet(), Gateway: "", // No gateway for local network. - NIC: tcpip.NICID(idx), + NIC: nicID, + } + s.Stack.RemoveRoutes(func(rt tcpip.Route) bool { + return rt.Equal(localRoute) }) + return nil } @@ -279,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { 0, // Support Ip/FragCreates. } case *inet.StatSNMPICMP: - in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats - out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ 0, // Icmp/InMsgs. - Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors. 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. @@ -298,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. 0, // Icmp/OutMsgs. - Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. - out.DstUnreachable.Value(), // OutDestUnreachs. - out.TimeExceeded.Value(), // OutTimeExcds. - out.ParamProblem.Value(), // OutParmProbs. - out.SrcQuench.Value(), // OutSrcQuenchs. - out.Redirect.Value(), // OutRedirects. - out.Echo.Value(), // OutEchos. - out.EchoReply.Value(), // OutEchoReps. - out.Timestamp.Value(), // OutTimestamps. - out.TimestampReply.Value(), // OutTimestampReps. - out.InfoRequest.Value(), // OutAddrMasks. - out.InfoReply.Value(), // OutAddrMaskReps. + Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. } case *inet.StatSNMPTCP: tcp := Metrics.TCP diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fd31479e5..bcc426e33 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -18,6 +18,7 @@ package socket import ( + "bytes" "fmt" "sync/atomic" "syscall" @@ -35,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/usermem" ) @@ -42,7 +44,79 @@ import ( // control messages. type ControlMessages struct { Unix transport.ControlMessages - IP tcpip.ControlMessages + IP IPControlMessages +} + +// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format. +func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + return p +} + +// NewIPControlMessages converts the tcpip ControlMessgaes (which does not +// have Linux specific format) to Linux format. +func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages { + var orgDstAddr linux.SockAddr + if cmgs.HasOriginalDstAddress { + orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) + } + return IPControlMessages{ + HasTimestamp: cmgs.HasTimestamp, + Timestamp: cmgs.Timestamp, + HasInq: cmgs.HasInq, + Inq: cmgs.Inq, + HasTOS: cmgs.HasTOS, + TOS: cmgs.TOS, + HasTClass: cmgs.HasTClass, + TClass: cmgs.TClass, + HasIPPacketInfo: cmgs.HasIPPacketInfo, + PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + OriginalDstAddress: orgDstAddr, + } +} + +// IPControlMessages contains socket control messages for IP sockets. +// This can contain Linux specific structures unlike tcpip.ControlMessages. +// +// +stateify savable +type IPControlMessages struct { + // HasTimestamp indicates whether Timestamp is valid/set. + HasTimestamp bool + + // Timestamp is the time (in ns) that the last packet used to create + // the read data was received. + Timestamp int64 + + // HasInq indicates whether Inq is valid/set. + HasInq bool + + // Inq is the number of bytes ready to be received. + Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS uint8 + + // HasTClass indicates whether TClass is valid/set. + HasTClass bool + + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo linux.ControlMessageIPPacketInfo + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress linux.SockAddr } // Release releases Unix domain socket credentials and rights. @@ -460,3 +534,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { panic(fmt.Sprintf("Unsupported socket family %v", family)) } } + +var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes() +var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes() +var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes() + +// Ntohs converts a 16-bit number from network byte order to host byte order. It +// assumes that the host is little endian. +func Ntohs(v uint16) uint16 { + return v<<8 | v>>8 +} + +// Htons converts a 16-bit number from host byte order to network byte order. It +// assumes that the host is little endian. +func Htons(v uint16) uint16 { + return Ntohs(v) +} + +// isLinkLocal determines if the given IPv6 address is link-local. This is the +// case when it has the fe80::/10 prefix. This check is used to determine when +// the NICID is relevant for a given IPv6 address. +func isLinkLocal(addr tcpip.Address) bool { + return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 +} + +// ConvertAddress converts the given address to a native format. +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { + switch family { + case linux.AF_UNIX: + var out linux.SockAddrUnix + out.Family = linux.AF_UNIX + l := len([]byte(addr.Addr)) + for i := 0; i < l; i++ { + out.Path[i] = int8(addr.Addr[i]) + } + + // Linux returns the used length of the address struct (including the + // null terminator) for filesystem paths. The Family field is 2 bytes. + // It is sometimes allowed to exclude the null terminator if the + // address length is the max. Abstract and empty paths always return + // the full exact length. + if l == 0 || out.Path[0] == 0 || l == len(out.Path) { + return &out, uint32(2 + l) + } + return &out, uint32(3 + l) + + case linux.AF_INET: + var out linux.SockAddrInet + copy(out.Addr[:], addr.Addr) + out.Family = linux.AF_INET + out.Port = Htons(addr.Port) + return &out, uint32(sockAddrInetSize) + + case linux.AF_INET6: + var out linux.SockAddrInet6 + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. + copy(out.Addr[12:], addr.Addr) + out.Addr[10] = 0xff + out.Addr[11] = 0xff + } else { + copy(out.Addr[:], addr.Addr) + } + out.Family = linux.AF_INET6 + out.Port = Htons(addr.Port) + if isLinkLocal(addr.Addr) { + out.Scope_id = uint32(addr.NIC) + } + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(gvisor.dev/issue/173): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + + default: + return nil, 0 + } +} + +// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the +// netstack representation taking any addresses into account. +func BytesToIPAddress(addr []byte) tcpip.Address { + if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { + return "" + } + return tcpip.Address(addr) +} + +// AddressAndFamily reads an sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. +// +// AddressAndFamily returns an address and its family. +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { + // Make sure we have at least 2 bytes for the address family. + if len(addr) < 2 { + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument + } + + // Get the rest of the fields based on the address family. + switch family := usermem.ByteOrder.Uint16(addr); family { + case linux.AF_UNIX: + path := addr[2:] + if len(path) > linux.UnixPathMax { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + // Drop the terminating NUL (if one exists) and everything after + // it for filesystem (non-abstract) addresses. + if len(path) > 0 && path[0] != 0 { + if n := bytes.IndexByte(path[1:], 0); n >= 0 { + path = path[:n+1] + } + } + return tcpip.FullAddress{ + Addr: tcpip.Address(path), + }, family, nil + + case linux.AF_INET: + var a linux.SockAddrInet + if len(addr) < sockAddrInetSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + return out, family, nil + + case linux.AF_INET6: + var a linux.SockAddrInet6 + if len(addr) < sockAddrInet6Size { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + if isLinkLocal(out.Addr) { + out.NIC = tcpip.NICID(a.Scope_id) + } + return out, family, nil + + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/173): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, family, nil + + default: + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported + } +} diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cc7408698..cce0acc33 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "socket_refs.go", package = "unix", prefix = "socketOperations", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketOperations", }, @@ -19,7 +19,7 @@ go_template_instance( out = "socket_vfs2_refs.go", package = "unix", prefix = "socketVFS2", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketVFS2", }, @@ -43,6 +43,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 26c3a51b9..3ebbd28b0 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "queue_refs.go", package = "transport", prefix = "queue", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "queue", }, @@ -44,6 +44,7 @@ go_library( "//pkg/ilist", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index aa4f3c04d..9f7aca305 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -118,33 +118,29 @@ var ( // NewConnectioned creates a new unbound connectionedEndpoint. func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { - return &connectionedEndpoint{ + return newConnectioned(ctx, stype, uid) +} + +func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint { + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { - a := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } - b := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } + a := newConnectioned(ctx, stype, uid) + b := newConnectioned(ctx, stype, uid) q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} - q1.EnableLeakCheck() + q1.InitRefs() q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} - q2.EnableLeakCheck() + q2.InitRefs() if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} @@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { - return &connectionedEndpoint{ + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // ID implements ConnectingEndpoint.ID. @@ -298,16 +296,17 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } + ne.ops.InitHandler(ne) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} - readQueue.EnableLeakCheck() + readQueue.InitRefs() ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} - writeQueue.EnableLeakCheck() + writeQueue.InitRefs() if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index f8aacca13..0813ad87d 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -42,8 +42,9 @@ var ( func NewConnectionless(ctx context.Context) Endpoint { ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} - q.EnableLeakCheck() + q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} + ep.ops.InitHandler(ep) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index d6fc03520..099a56281 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,8 +16,6 @@ package transport import ( - "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -32,6 +30,8 @@ import ( const initialLimit = 16 * 1024 // A RightsControlMessage is a control message containing FDs. +// +// +stateify savable type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. Clone() RightsControlMessage @@ -178,10 +178,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool sets a socket option for simple cases when a value has - // the int type. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt sets a socket option for simple cases when a value has // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -189,10 +185,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool gets a socket option for simple cases when a return - // value has the int type. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -201,8 +193,12 @@ type Endpoint interface { // procfs. State() uint32 - // LastError implements tcpip.Endpoint.LastError. + // LastError clears and returns the last error reported by the endpoint. LastError() *tcpip.Error + + // SocketOptions returns the structure which contains all the socket + // level options. + SocketOptions() *tcpip.SocketOptions } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket @@ -336,7 +332,7 @@ type Receiver interface { RecvMaxQueueSize() int64 // Release releases any resources owned by the Receiver. It should be - // called before droping all references to a Receiver. + // called before dropping all references to a Receiver. Release(ctx context.Context) } @@ -487,7 +483,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds c := q.control.Clone() // Don't consume data since we are peeking. - copied, data, _ = vecCopy(data, q.buffer) + copied, _, _ = vecCopy(data, q.buffer) return copied, copied, c, false, q.addr, notify, nil } @@ -572,6 +568,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds return copied, copied, c, cmTruncated, q.addr, notify, nil } +// Release implements Receiver.Release. +func (q *streamQueueReceiver) Release(ctx context.Context) { + q.queueReceiver.Release(ctx) + q.control.Release(ctx) +} + // A ConnectedEndpoint is an Endpoint that can be used to send Messages. type ConnectedEndpoint interface { // Passcred implements Endpoint.Passcred. @@ -619,7 +621,7 @@ type ConnectedEndpoint interface { SendMaxQueueSize() int64 // Release releases any resources owned by the ConnectedEndpoint. It should - // be called before droping all references to a ConnectedEndpoint. + // be called before dropping all references to a ConnectedEndpoint. Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to @@ -728,10 +730,7 @@ func (e *connectedEndpoint) CloseUnread() { // +stateify savable type baseEndpoint struct { *waiter.Queue - - // passcred specifies whether SCM_CREDENTIALS socket control messages are - // enabled on this endpoint. Must be accessed atomically. - passcred int32 + tcpip.DefaultSocketOptionsHandler // Mutex protects the below fields. sync.Mutex `state:"nosave"` @@ -747,8 +746,8 @@ type baseEndpoint struct { // or may be used if the endpoint is connected. path string - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption + // ops is used to get socket level options. + ops tcpip.SocketOptions } // EventRegister implements waiter.Waitable.EventRegister. @@ -773,7 +772,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { // Passcred implements Credentialer.Passcred. func (e *baseEndpoint) Passcred() bool { - return atomic.LoadInt32(&e.passcred) != 0 + return e.SocketOptions().GetPassCred() } // ConnectedPasscred implements Credentialer.ConnectedPasscred. @@ -783,14 +782,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool { return e.connected != nil && e.connected.Passcred() } -func (e *baseEndpoint) setPasscred(pc bool) { - if pc { - atomic.StoreInt32(&e.passcred, 1) - } else { - atomic.StoreInt32(&e.passcred, 0) - } -} - // Connected implements ConnectingEndpoint.Connected. func (e *baseEndpoint) Connected() bool { return e.receiver != nil && e.connected != nil @@ -846,24 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - e.linger = *v - e.Unlock() - } - return nil -} - -func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.BroadcastOption: - case tcpip.PasscredOption: - e.setPasscred(v) - case tcpip.ReuseAddressOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } return nil } @@ -877,20 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption: - return false, nil - - case tcpip.PasscredOption: - return e.Passcred(), nil - - default: - log.Warningf("Unsupported socket option: %d", opt) - return false, tcpip.ErrUnknownProtocolOption - } -} - func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -954,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - *o = e.linger - e.Unlock() - return nil - - default: - log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption - } + log.Warningf("Unsupported socket option: %T", opt) + return tcpip.ErrUnknownProtocolOption } // LastError implements Endpoint.LastError. @@ -972,6 +922,11 @@ func (*baseEndpoint) LastError() *tcpip.Error { return nil } +// SocketOptions implements Endpoint.SocketOptions. +func (e *baseEndpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} + // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error { diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index a4a76d0a3..c59297c80 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -80,8 +80,7 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty stype: stype, }, } - s.EnableLeakCheck() - + s.InitRefs() return fs.NewFile(ctx, d, flags, &s) } @@ -137,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, family, err := netstack.AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { if err == syserr.ErrAddressFamilyNotSupported { err = syserr.ErrInvalidArgument @@ -170,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -182,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -256,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -648,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -683,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 678355fb9..27f705bb2 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // returns a corresponding file description. func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) @@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 stype: stype, }, } + sock.InitRefs() sock.LockFD.Init(locks) vfsfd := &sock.vfsfd if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ @@ -171,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 0ea4aab8b..563d60578 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -12,10 +12,12 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/time", + "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/state/statefile", "//pkg/syserror", diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 245d2c5cf..167754537 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -19,10 +19,12 @@ import ( "fmt" "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/pkg/syserror" @@ -57,7 +59,7 @@ type SaveOpts struct { } // Save saves the system state. -func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { +func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Watchdog) error { log.Infof("Sandbox save started, pausing all tasks.") k.Pause() k.ReceiveTaskStates() @@ -81,7 +83,7 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { err = ErrStateFile{err} } else { // Save the kernel. - err = k.SaveTo(wc) + err = k.SaveTo(ctx, wc) // ENOSPC is a state file error. This error can only come from // writing the state file, and not from fs.FileOperations.Fsync @@ -108,7 +110,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error { +func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -118,5 +120,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) er previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(r, n, clocks) + return k.LoadFrom(ctx, r, n, clocks, vfsOpts) } diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index a920180d3..d36a64ffc 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -32,8 +32,8 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/kernel", + "//pkg/sentry/socket", "//pkg/sentry/socket/netlink", - "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", "//pkg/usermem", ], diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index cc5f70cd4..d943a7cb1 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" - "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/usermem" ) @@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := netstack.AddressAndFamily(b) + fa, _, err := socket.AddressAndFamily(b) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 9c9def7cd..cff442846 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 519066a47..8db587401 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -646,7 +646,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } - fSetOwn(t, file, set) + fSetOwn(t, int(fd), file, set) return 0, nil, nil case linux.FIOGETOWN, linux.SIOCGPGRP: @@ -901,8 +901,8 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 { // // If who is positive, it represents a PID. If negative, it represents a PGID. // If the PID or PGID is invalid, the owner is silently unset. -func fSetOwn(t *kernel.Task, file *fs.File, who int32) error { - a := file.Async(fasync.New).(*fasync.FileAsync) +func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error { + a := file.Async(fasync.New(fd)).(*fasync.FileAsync) if who < 0 { // Check for overflow before flipping the sign. if who-1 > who { @@ -1049,7 +1049,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN: return uintptr(fGetOwn(t, file)), nil, nil case linux.F_SETOWN: - return 0, nil, fSetOwn(t, file, args[2].Int()) + return 0, nil, fSetOwn(t, int(fd), file, args[2].Int()) case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) @@ -1062,7 +1062,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - a := file.Async(fasync.New).(*fasync.FileAsync) + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) switch owner.Type { case linux.F_OWNER_TID: task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID)) @@ -1111,6 +1111,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } n, err := sz.SetFifoSize(int64(args[2].Int())) return uintptr(n), nil, err + case linux.F_GETSIG: + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) + return uintptr(a.Signal()), nil, nil + case linux.F_SETSIG: + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) + return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 849a47476..f7135ea46 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -32,7 +32,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return 0, syserror.EINVAL } - r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize) + r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize) r.SetFlags(linuxToFlags(flags).Settable()) defer r.DecRef(t) diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index 309c183a3..88cd234d1 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -90,6 +90,9 @@ var setableLimits = map[limits.LimitType]struct{}{ limits.FileSize: {}, limits.MemoryLocked: {}, limits.Stack: {}, + // RSS can be set, but it's not enforced because Linux doesn't enforce it + // either: "This limit has effect only in Linux 2.4.x, x < 30" + limits.Rss: {}, // These are not enforced, but we include them here to avoid returning // EPERM, since some apps expect them to succeed. limits.Core: {}, diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 47dadb800..a62a6b3b5 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -129,13 +129,35 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal v, err := getPID(t, id, num) return uintptr(v), nil, err - case linux.IPC_INFO, - linux.SEM_INFO, - linux.IPC_STAT, + case linux.IPC_STAT: + arg := args[3].Pointer() + ds, err := ipcStat(t, id) + if err == nil { + _, err = ds.CopyOut(t, arg) + } + + return 0, nil, err + + case linux.GETZCNT: + v, err := getZCnt(t, id, num) + return uintptr(v), nil, err + + case linux.GETNCNT: + v, err := getNCnt(t, id, num) + return uintptr(v), nil, err + + case linux.IPC_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.IPCInfo() + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil + + case linux.SEM_INFO, linux.SEM_STAT, - linux.SEM_STAT_ANY, - linux.GETNCNT, - linux.GETZCNT: + linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -171,6 +193,16 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP return set.Change(t, creds, owner, perms) } +func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.GetStat(creds) +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) @@ -240,3 +272,23 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) { } return int32(tg.ID()), nil } + +func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return 0, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.CountZeroWaiters(num, creds) +} + +func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return 0, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.CountNegativeWaiters(num, creds) +} diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index e748d33d8..d639c9bf7 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(target.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) + info.SetPID(int32(target.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) if err := target.SendGroupSignal(info); err != syserror.ESRCH { return 0, nil, err } @@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) err := tg.SendSignal(info) if err == syserror.ESRCH { // ESRCH is ignored because it means the task @@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) // See note above regarding ESRCH race above. if err := tg.SendSignal(info); err != syserror.ESRCH { lastErr = err @@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI Signo: int32(sig), Code: arch.SignalInfoTkill, } - info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 46616c961..1c4cdb0dd 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -41,6 +41,7 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB inCh chan struct{} outCh chan struct{} ) + for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) opts.Length -= n @@ -61,23 +62,28 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB inW, _ := waiter.NewChannelEntry(inCh) inFile.EventRegister(&inW, EventMaskRead) defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Need to refresh readiness. + continue } if err = t.Block(inCh); err != nil { break } } - if outFile.Readiness(EventMaskWrite) == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, EventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + // Don't bother checking readiness of the outFile, because it's not a + // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds + // can be "ready" but will reject writes of certain sizes with + // EWOULDBLOCK. + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + // We might be ready to write now. Try again before + // blocking. + continue + } + if err = t.Block(outCh); err != nil { + break } } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 39ca9ea97..8e7ac0ffe 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -159,7 +159,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user defer wd.DecRef(t) } - // Load the new TaskContext. + // Load the new TaskImage. remainingTraversals := uint(linux.MaxSymlinkTraversals) loadArgs := loader.LoadArgs{ Opener: fsbridge.NewFSLookup(t.MountNamespace(), root, wd), @@ -173,12 +173,12 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user Features: t.Arch().FeatureSet(), } - tc, se := t.Kernel().LoadTaskImage(t, loadArgs) + image, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } - ctrl, err := t.Execve(tc) + ctrl, err := t.Execve(image) return 0, ctrl, err } @@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal si := arch.SignalInfo{ Signo: int32(linux.SIGCHLD), } - si.SetPid(int32(wr.TID)) - si.SetUid(int32(wr.UID)) + si.SetPID(int32(wr.TID)) + si.SetUID(int32(wr.UID)) // TODO(b/73541790): convert kernel.ExitStatus to functions and make // WaitResult.Status a linux.WaitStatus. s := syscall.WaitStatus(wr.Status) diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go index c8ce2aabc..7a409620d 100644 --- a/pkg/sentry/syscalls/linux/vfs2/execve.go +++ b/pkg/sentry/syscalls/linux/vfs2/execve.go @@ -109,7 +109,7 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user executable = fsbridge.NewVFSFile(file) } - // Load the new TaskContext. + // Load the new TaskImage. mntns := t.MountNamespaceVFS2() wd := t.FSContext().WorkingDirectoryVFS2() defer wd.DecRef(t) @@ -126,11 +126,11 @@ func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr user Features: t.Arch().FeatureSet(), } - tc, se := t.Kernel().LoadTaskImage(t, loadArgs) + image, se := t.Kernel().LoadTaskImage(t, loadArgs) if se != nil { return 0, nil, se.ToError() } - ctrl, err := t.Execve(tc) + ctrl, err := t.Execve(image) return 0, ctrl, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 36e89700e..7dd9ef857 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -165,7 +165,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ownerType = linux.F_OWNER_PGRP who = -who } - return 0, nil, setAsyncOwner(t, file, ownerType, who) + return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who) case linux.F_GETOWN_EX: owner, hasOwner := getAsyncOwner(t, file) if !hasOwner { @@ -179,7 +179,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID) + return 0, nil, setAsyncOwner(t, int(fd), file, owner.Type, owner.PID) case linux.F_SETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { @@ -207,6 +207,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err case linux.F_SETLK, linux.F_SETLKW: return 0, nil, posixLock(t, args, file, cmd) + case linux.F_GETSIG: + a := file.AsyncHandler() + if a == nil { + // Default behavior aka SIGIO. + return 0, nil, nil + } + return uintptr(a.(*fasync.FileAsync).Signal()), nil, nil + case linux.F_SETSIG: + a := file.SetAsyncHandler(fasync.NewVFS2(int(fd))).(*fasync.FileAsync) + return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. return 0, nil, syserror.EINVAL @@ -241,7 +251,7 @@ func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwne } } -func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error { +func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, pid int32) error { switch ownerType { case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP: // Acceptable type. @@ -249,7 +259,7 @@ func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32 return syserror.EINVAL } - a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync) + a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync) if pid == 0 { a.ClearOwner() return nil diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 2806c3f6f..20c264fef 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -100,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ownerType = linux.F_OWNER_PGRP who = -who } - return 0, nil, setAsyncOwner(t, file, ownerType, who) + return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who) } ret, err := file.Ioctl(t, t.MemoryManager(), args) diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 035e2a6b0..8bb763a47 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -343,8 +343,8 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Copy data. var ( - n int64 - err error + total int64 + err error ) dw := dualWaiter{ inFile: inFile, @@ -357,13 +357,20 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // can block. nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0 if outIsPipe { - for n < count { - var spliceN int64 - spliceN, err = outPipeFD.SpliceFromNonPipe(t, inFile, offset, count) + for { + var n int64 + n, err = outPipeFD.SpliceFromNonPipe(t, inFile, offset, count-total) if offset != -1 { - offset += spliceN + offset += n + } + total += n + if total == count { + break + } + if err == nil && t.Interrupted() { + err = syserror.ErrInterrupted + break } - n += spliceN if err == syserror.ErrWouldBlock && !nonBlock { err = dw.waitForBoth(t) } @@ -374,7 +381,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } else { // Read inFile to buffer, then write the contents to outFile. buf := make([]byte, count) - for n < count { + for { var readN int64 if offset != -1 { readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{}) @@ -382,7 +389,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } else { readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) } - n += readN // Write all of the bytes that we read. This may need // multiple write calls to complete. @@ -398,7 +404,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // We didn't complete the write. Only report the bytes that were actually // written, and rewind offsets as needed. notWritten := int64(len(wbuf)) - n -= notWritten + readN -= notWritten if offset == -1 { // We modified the offset of the input file itself during the read // operation. Rewind it. @@ -415,6 +421,16 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc break } } + + total += readN + buf = buf[readN:] + if total == count { + break + } + if err == nil && t.Interrupted() { + err = syserror.ErrInterrupted + break + } if err == syserror.ErrWouldBlock && !nonBlock { err = dw.waitForBoth(t) } @@ -432,7 +448,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } - if n != 0 { + if total != 0 { inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent) outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent) @@ -445,7 +461,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // We can only pass a single file to handleIOError, so pick inFile arbitrarily. // This is used only for debugging purposes. - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "sendfile", inFile) + return uintptr(total), nil, slinux.HandleIOErrorVFS2(t, total != 0, err, syserror.ERESTARTSYS, "sendfile", inFile) } // dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not @@ -480,18 +496,17 @@ func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { // waitForOut waits for dw.outfile to be read. func (dw *dualWaiter) waitForOut(t *kernel.Task) error { - if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if dw.outCh == nil { - dw.outW, dw.outCh = waiter.NewChannelEntry(nil) - dw.outFile.EventRegister(&dw.outW, eventMaskWrite) - // We might be ready now. Try again before blocking. - return nil - } - if err := t.Block(dw.outCh); err != nil { - return err - } - } - return nil + // Don't bother checking readiness of the outFile, because it's not a + // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds + // can be "ready" but will reject writes of certain sizes with + // EWOULDBLOCK. See b/172075629, b/170743336. + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready to write now. Try again before blocking. + return nil + } + return t.Block(dw.outCh) } // destroy cleans up resources help by dw. No more calls to wait* can occur diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index c855608db..a3868bf16 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -32,7 +32,7 @@ go_template_instance( out = "file_description_refs.go", package = "vfs", prefix = "FileDescription", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FileDescription", }, @@ -43,7 +43,7 @@ go_template_instance( out = "mount_namespace_refs.go", package = "vfs", prefix = "MountNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "MountNamespace", }, @@ -54,7 +54,7 @@ go_template_instance( out = "filesystem_refs.go", package = "vfs", prefix = "Filesystem", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Filesystem", }, @@ -87,6 +87,7 @@ go_library( "pathname.go", "permissions.go", "resolving_path.go", + "save_restore.go", "vfs.go", ], visibility = ["//pkg/sentry:internal"], @@ -99,10 +100,12 @@ go_library( "//pkg/gohacks", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", "//pkg/sentry/fs/lock", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 8f36c3e3b..072655fe8 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -74,7 +74,7 @@ type epollInterestKey struct { // +stateify savable type epollInterest struct { // epoll is the owning EpollInstance. epoll is immutable. - epoll *EpollInstance + epoll *EpollInstance `state:"wait"` // key is the file to which this epollInterest applies. key is immutable. key epollInterestKey @@ -204,8 +204,8 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin file.EventRegister(&epi.waiter, wmask) // Check if the file is already ready. - if file.Readiness(wmask)&wmask != 0 { - epi.Callback(nil) + if m := file.Readiness(wmask) & wmask; m != 0 { + epi.Callback(nil, m) } // Add epi to file.epolls so that it is removed when the last @@ -274,8 +274,8 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event file.EventRegister(&epi.waiter, wmask) // Check if the file is already ready with the new mask. - if file.Readiness(wmask)&wmask != 0 { - epi.Callback(nil) + if m := file.Readiness(wmask) & wmask; m != 0 { + epi.Callback(nil, m) } return nil @@ -311,7 +311,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error } // Callback implements waiter.EntryCallback.Callback. -func (epi *epollInterest) Callback(*waiter.Entry) { +func (epi *epollInterest) Callback(*waiter.Entry, waiter.EventMask) { newReady := false epi.epoll.mu.Lock() if !epi.ready { diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 183957ad8..5321ac80a 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -15,12 +15,14 @@ package vfs import ( + "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" @@ -42,7 +44,7 @@ import ( type FileDescription struct { FileDescriptionRefs - // flagsMu protects statusFlags and asyncHandler below. + // flagsMu protects `statusFlags`, `saved`, and `asyncHandler` below. flagsMu sync.Mutex `state:"nosave"` // statusFlags contains status flags, "initialized by open(2) and possibly @@ -51,6 +53,11 @@ type FileDescription struct { // access to asyncHandler. statusFlags uint32 + // saved is true after beforeSave is called. This is used to prevent + // double-unregistration of asyncHandler. This does not work properly for + // save-resume, which is not currently supported in gVisor (see b/26588733). + saved bool `state:"nosave"` + // asyncHandler handles O_ASYNC signal generation. It is set with the // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must // also be set by fcntl(2). @@ -133,7 +140,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou } } - fd.EnableLeakCheck() + fd.InitRefs() // Remove "file creation flags" to mirror the behavior from file.f_flags in // fs/open.c:do_dentry_open. @@ -183,8 +190,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.vd.DecRef(ctx) fd.flagsMu.Lock() - // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. - if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + if !fd.saved && fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { fd.asyncHandler.Unregister(fd) } fd.asyncHandler = nil @@ -584,7 +590,11 @@ func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of if !fd.readable { return 0, syserror.EBADF } - return fd.impl.PRead(ctx, dst, offset, opts) + start := fsmetric.StartReadWait() + n, err := fd.impl.PRead(ctx, dst, offset, opts) + fsmetric.Reads.Increment() + fsmetric.FinishReadWait(fsmetric.ReadWait, start) + return n, err } // Read is similar to PRead, but does not specify an offset. @@ -592,7 +602,11 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt if !fd.readable { return 0, syserror.EBADF } - return fd.impl.Read(ctx, dst, opts) + start := fsmetric.StartReadWait() + n, err := fd.impl.Read(ctx, dst, opts) + fsmetric.Reads.Increment() + fsmetric.FinishReadWait(fsmetric.ReadWait, start) + return n, err } // PWrite writes src to the file represented by fd, starting at the given @@ -826,44 +840,27 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn return fd.asyncHandler } -// FileReadWriteSeeker is a helper struct to pass a FileDescription as -// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. -type FileReadWriteSeeker struct { - FD *FileDescription - Ctx context.Context - ROpts ReadOptions - WOpts WriteOptions -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts) - return int(n), err -} - -// Read implements io.ReadWriteSeeker.Read. -func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.Read(f.Ctx, dst, f.ROpts) - return int(n), err -} - -// Seek implements io.ReadWriteSeeker.Seek. -func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { - return f.FD.Seek(f.Ctx, offset, int32(whence)) -} - -// WriteAt implements io.WriterAt.WriteAt. -func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts) - return int(n), err -} - -// Write implements io.ReadWriteSeeker.Write. -func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { - buf := usermem.BytesIOSequence(p) - n, err := f.FD.Write(f.Ctx, buf, f.WOpts) - return int(n), err +// CopyRegularFileData copies data from srcFD to dstFD until reading from srcFD +// returns EOF or an error. It returns the number of bytes copied. +func CopyRegularFileData(ctx context.Context, dstFD, srcFD *FileDescription) (int64, error) { + done := int64(0) + buf := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size + for { + readN, readErr := srcFD.Read(ctx, buf, ReadOptions{}) + if readErr != nil && readErr != io.EOF { + return done, readErr + } + src := buf.TakeFirst64(readN) + for src.NumBytes() != 0 { + writeN, writeErr := dstFD.Write(ctx, src, WriteOptions{}) + done += writeN + src = src.DropFirst64(writeN) + if writeErr != nil { + return done, writeErr + } + } + if readErr == io.EOF { + return done, nil + } + } } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index c93d94634..2c4b81e78 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -48,7 +48,7 @@ type Filesystem struct { // Init must be called before first use of fs. func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) { - fs.EnableLeakCheck() + fs.InitRefs() fs.vfs = vfsObj fs.fsType = fsType fs.impl = impl diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 2d27d9d35..ba6e6ed49 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -71,7 +71,7 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { return vfs.PrependPathAtVFSRootError{} } - if &d.vfsd == mnt.Root() { + if mnt != nil && &d.vfsd == mnt.Root() { return nil } if d.parent == nil { @@ -81,3 +81,12 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath d = d.parent } } + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func DebugPathname(d *Dentry) string { + var b fspath.Builder + _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 3f0b8f45b..107171b61 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -65,7 +65,7 @@ type Inotify struct { // queue is used to notify interested parties when the inotify instance // becomes readable or writable. - queue waiter.Queue `state:"nosave"` + queue waiter.Queue // evMu *only* protects the events list. We need a separate lock while // queuing events: using mu may violate lock ordering, since at that point diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 55783d4eb..1ff202f2a 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package lock provides POSIX and BSD style file locking for VFS2 file -// implementations. -// -// The actual implementations can be found in the lock package under -// sentry/fs/lock. package vfs import ( diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 78f115bfa..d865fd603 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) @@ -106,6 +107,7 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount if opts.ReadOnly { mnt.setReadOnlyLocked(true) } + refsvfs2.Register(mnt) return mnt } @@ -167,7 +169,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth Owner: creds.UserNamespace, mountpoints: make(map[*Dentry]uint32), } - mntns.EnableLeakCheck() + mntns.InitRefs() mntns.root = newMount(vfs, fs, root, mntns, opts) return mntns, nil } @@ -470,11 +472,14 @@ func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { // tryIncMountedRef does not require that a reference is held on mnt. func (mnt *Mount) tryIncMountedRef() bool { for { - refs := atomic.LoadInt64(&mnt.refs) - if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted + r := atomic.LoadInt64(&mnt.refs) + if r <= 0 { // r < 0 => MSB set => eagerly unmounted return false } - if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&mnt.refs, r, r+1) { + if mnt.LogRefs() { + refsvfs2.LogTryIncRef(mnt, r+1) + } return true } } @@ -484,29 +489,58 @@ func (mnt *Mount) tryIncMountedRef() bool { func (mnt *Mount) IncRef() { // In general, negative values for mnt.refs are valid because the MSB is // the eager-unmount bit. - atomic.AddInt64(&mnt.refs, 1) + r := atomic.AddInt64(&mnt.refs, 1) + if mnt.LogRefs() { + refsvfs2.LogIncRef(mnt, r) + } } // DecRef decrements mnt's reference count. func (mnt *Mount) DecRef(ctx context.Context) { - refs := atomic.AddInt64(&mnt.refs, -1) - if refs&^math.MinInt64 == 0 { // mask out MSB - var vd VirtualDentry - if mnt.parent() != nil { - mnt.vfs.mountMu.Lock() - mnt.vfs.mounts.seq.BeginWrite() - vd = mnt.vfs.disconnectLocked(mnt) - mnt.vfs.mounts.seq.EndWrite() - mnt.vfs.mountMu.Unlock() - } - if mnt.root != nil { - mnt.root.DecRef(ctx) - } - mnt.fs.DecRef(ctx) - if vd.Ok() { - vd.DecRef(ctx) - } + r := atomic.AddInt64(&mnt.refs, -1) + if mnt.LogRefs() { + refsvfs2.LogDecRef(mnt, r) + } + if r&^math.MinInt64 == 0 { // mask out MSB + refsvfs2.Unregister(mnt) + mnt.destroy(ctx) + } +} + +func (mnt *Mount) destroy(ctx context.Context) { + var vd VirtualDentry + if mnt.parent() != nil { + mnt.vfs.mountMu.Lock() + mnt.vfs.mounts.seq.BeginWrite() + vd = mnt.vfs.disconnectLocked(mnt) + mnt.vfs.mounts.seq.EndWrite() + mnt.vfs.mountMu.Unlock() } + if mnt.root != nil { + mnt.root.DecRef(ctx) + } + mnt.fs.DecRef(ctx) + if vd.Ok() { + vd.DecRef(ctx) + } +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (mnt *Mount) RefType() string { + return "vfs.Mount" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (mnt *Mount) LeakMessage() string { + return fmt.Sprintf("[vfs.Mount %p] reference count of %d instead of 0", mnt, atomic.LoadInt64(&mnt.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (mnt *Mount) LogRefs() bool { + return false } // DecRef decrements mntns' reference count. diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index cb8c56bd3..cb882a983 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -29,7 +29,7 @@ func TestMountTableLookupEmpty(t *testing.T) { parent := &Mount{} point := &Dentry{} if m := mt.Lookup(parent, point); m != nil { - t.Errorf("empty mountTable lookup: got %p, wanted nil", m) + t.Errorf("Empty mountTable lookup: got %p, wanted nil", m) } } @@ -111,13 +111,16 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { k := keys[i&(numMounts-1)] m := mt.Lookup(k.mount, k.dentry) if m == nil { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -167,13 +170,16 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { m := ms[k] mu.RUnlock() if m == nil { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -220,14 +226,17 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { k := keys[i&(numMounts-1)] mi, ok := ms.Load(k) if !ok { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } m := mi.(*Mount) if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -264,7 +273,7 @@ func BenchmarkMountTableNegativeLookup(b *testing.B) { k := negkeys[i&(numMounts-1)] m := mt.Lookup(k.mount, k.dentry) if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) @@ -300,7 +309,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { m := ms[k] mu.RUnlock() if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) @@ -333,7 +342,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { k := negkeys[i&(numMounts-1)] m, _ := ms.Load(k) if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b7d122d22..0df023713 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.12 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - package vfs import ( @@ -41,6 +36,15 @@ type mountKey struct { point unsafe.Pointer // *Dentry } +var ( + mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil)) + mountKeySeed = sync.RandUintptr() +) + +func (k *mountKey) hash() uintptr { + return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed) +} + func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry { } } -func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } - // Invariant: mnt.key.parent == nil. vd.Ok(). func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } -func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } - // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). // // mountTable.Init() must be called on new mountTables before use. -// -// +stateify savable type mountTable struct { // mountTable is implemented as a seqcount-protected hash table that // resolves collisions with linear probing, featuring Robin Hood insertion @@ -84,8 +82,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq sync.SeqCount `state:"nosave"` - seed uint32 // for hashing keys + seq sync.SeqCount `state:"nosave"` // size holds both length (number of elements) and capacity (number of // slots): capacity is stored as its base-2 log (referred to as order) in @@ -98,7 +95,6 @@ type mountTable struct { // length and cap in separate uint32s) for ~free. size uint64 - // FIXME(gvisor.dev/issue/1663): Slots need to be saved. slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init } @@ -151,7 +147,6 @@ func init() { // Init must be called exactly once on each mountTable before use. func (mt *mountTable) Init() { - mt.seed = rand32() mt.size = mtInitOrder mt.slots = newMountTableSlots(mtInitCap) } @@ -168,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := key.hash() loop: for { @@ -212,6 +207,26 @@ loop: } } +// Range calls f on each Mount in mt. If f returns false, Range stops iteration +// and returns immediately. +func (mt *mountTable) Range(f func(*Mount) bool) { + tcap := uintptr(1) << (mt.size & mtSizeOrderMask) + slotPtr := mt.slots + last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes)) + for { + slot := (*mountSlot)(slotPtr) + if slot.value != nil { + if !f((*Mount)(slot.value)) { + return + } + } + if slotPtr == last { + return + } + slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes) + } +} + // Insert inserts the given mount into mt. // // Preconditions: mt must not already contain a Mount with the same mount point @@ -228,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must not already contain a Mount with the same mount point and parent. func (mt *mountTable) insertSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() // We're under the maximum load factor if: // @@ -327,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must contain mount. func (mt *mountTable) removeSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 slots := mt.slots @@ -367,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) { off = (off + mountSlotBytes) & offmask } } - -//go:linkname memhash runtime.memhash -func memhash(p unsafe.Pointer, seed, s uintptr) uintptr - -//go:linkname rand32 runtime.fastrand -func rand32() uint32 diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go new file mode 100644 index 000000000..8998a82dd --- /dev/null +++ b/pkg/sentry/vfs/save_restore.go @@ -0,0 +1,147 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/waiter" +) + +// FilesystemImplSaveRestoreExtension is an optional extension to +// FilesystemImpl. +type FilesystemImplSaveRestoreExtension interface { + // PrepareSave prepares this filesystem for serialization. + PrepareSave(ctx context.Context) error + + // CompleteRestore completes restoration from checkpoint for this + // filesystem after deserialization. + CompleteRestore(ctx context.Context, opts CompleteRestoreOptions) error +} + +// PrepareSave prepares all filesystems for serialization. +func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.PrepareSave(ctx); err != nil { + ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to prepare for serialization", failures) + } + return nil +} + +// CompleteRestore completes restoration from checkpoint for all filesystems +// after deserialization. +func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.CompleteRestore(ctx, *opts); err != nil { + ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures) + } + return nil +} + +// CompleteRestoreOptions contains options to +// VirtualFilesystem.CompleteRestore() and +// FilesystemImplSaveRestoreExtension.CompleteRestore(). +type CompleteRestoreOptions struct { + // If ValidateFileSizes is true, filesystem implementations backed by + // remote filesystems should verify that file sizes have not changed + // between checkpoint and restore. + ValidateFileSizes bool + + // If ValidateFileModificationTimestamps is true, filesystem + // implementations backed by remote filesystems should validate that file + // mtimes have not changed between checkpoint and restore. + ValidateFileModificationTimestamps bool +} + +// saveMounts is called by stateify. +func (vfs *VirtualFilesystem) saveMounts() []*Mount { + if atomic.LoadPointer(&vfs.mounts.slots) == nil { + // vfs.Init() was never called. + return nil + } + var mounts []*Mount + vfs.mounts.Range(func(mount *Mount) bool { + mounts = append(mounts, mount) + return true + }) + return mounts +} + +// saveKey is called by stateify. +func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } + +// loadMounts is called by stateify. +func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { + if mounts == nil { + return + } + vfs.mounts.Init() + for _, mount := range mounts { + vfs.mounts.Insert(mount) + } +} + +// loadKey is called by stateify. +func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } + +func (mnt *Mount) afterLoad() { + if atomic.LoadInt64(&mnt.refs) != 0 { + refsvfs2.Register(mnt) + } +} + +// afterLoad is called by stateify. +func (epi *epollInterest) afterLoad() { + // Mark all epollInterests as ready after restore so that the next call to + // EpollInstance.ReadEvents() rechecks their readiness. + epi.Callback(nil, waiter.EventMaskFromLinux(epi.mask)) +} + +// beforeSave is called by stateify. +func (fd *FileDescription) beforeSave() { + fd.saved = true + if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + fd.asyncHandler.Unregister(fd) + } +} + +// afterLoad is called by stateify. +func (fd *FileDescription) afterLoad() { + if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + fd.asyncHandler.Register(fd) + } +} diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 38d2701d2..6fd1bb0b2 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -41,6 +41,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" @@ -71,7 +72,7 @@ type VirtualFilesystem struct { // points. // // mounts is analogous to Linux's mount_hashtable. - mounts mountTable + mounts mountTable `state:".([]*Mount)"` // mountpoints maps mount points to mounts at those points in all // namespaces. mountpoints is protected by mountMu. @@ -381,6 +382,8 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia // OpenAt returns a FileDescription providing access to the file at the given // path. A reference is taken on the returned FileDescription. func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) { + fsmetric.Opens.Increment() + // Remove: // // - O_CLOEXEC, which affects file descriptors and therefore must be @@ -780,23 +783,27 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre // SyncAllFilesystems has the semantics of Linux's sync(2). func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + var retErr error + for fs := range vfs.getFilesystems() { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef(ctx) + } + return retErr +} + +func (vfs *VirtualFilesystem) getFilesystems() map[*Filesystem]struct{} { fss := make(map[*Filesystem]struct{}) vfs.filesystemsMu.Lock() + defer vfs.filesystemsMu.Unlock() for fs := range vfs.filesystems { if !fs.TryIncRef() { continue } fss[fs] = struct{}{} } - vfs.filesystemsMu.Unlock() - var retErr error - for fs := range fss { - if err := fs.impl.Sync(ctx); err != nil && retErr == nil { - retErr = err - } - fs.DecRef(ctx) - } - return retErr + return fss } // MkdirAllAt recursively creates non-existent directories on the given path diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index bbafb8b7f..8e3146d8d 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -336,10 +336,9 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo buf.WriteString(fmt.Sprintf("Sentry detected %d stuck task(s):\n", len(offenders))) for t, o := range offenders { tid := w.k.TaskSet().Root.IDOfTask(t) - buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime))) + buf.WriteString(fmt.Sprintf("\tTask tid: %v (goroutine %d), entered RunSys state %v ago.\n", tid, t.GoroutineID(), now.Sub(o.lastUpdateTime))) } - - buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine") + buf.WriteString("Search for 'goroutine <id>' in the stack dump to find the offending goroutine(s)") // Force stack dump only if a new task is detected. w.doAction(w.TaskTimeoutAction, newTaskFound, &buf) diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD index f08599ebd..cb0001852 100644 --- a/pkg/shim/runsc/BUILD +++ b/pkg/shim/runsc/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "@com_github_containerd_containerd//log:go_default_library", "@com_github_containerd_go_runc//:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go index c5cf68efa..aedaf5ee5 100644 --- a/pkg/shim/runsc/runsc.go +++ b/pkg/shim/runsc/runsc.go @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package runsc provides an API to interact with runsc command line. package runsc import ( @@ -28,15 +29,37 @@ import ( "syscall" "time" + "github.com/containerd/containerd/log" runc "github.com/containerd/go-runc" specs "github.com/opencontainers/runtime-spec/specs-go" ) -var Monitor runc.ProcessMonitor = runc.Monitor - // DefaultCommand is the default command for Runsc. const DefaultCommand = "runsc" +// Monitor is the default process monitor to be used by runsc. +var Monitor runc.ProcessMonitor = &LogMonitor{Next: runc.Monitor} + +// LogMonitor implements the runc.ProcessMonitor interface, logging the command +// that is getting executed, and then forwarding the call to another +// implementation. +type LogMonitor struct { + Next runc.ProcessMonitor +} + +// Start implements runc.ProcessMonitor. +func (l *LogMonitor) Start(cmd *exec.Cmd) (chan runc.Exit, error) { + log.L.Debugf("Executing: %s", cmd.Args) + return l.Next.Start(cmd) +} + +// Wait implements runc.ProcessMonitor. +func (l *LogMonitor) Wait(cmd *exec.Cmd, ch chan runc.Exit) (int, error) { + status, err := l.Next.Wait(cmd, ch) + log.L.Debugf("Command exit code: %d, err: %v", status, err) + return status, err +} + // Runsc is the client to the runsc cli. type Runsc struct { Command string @@ -74,6 +97,7 @@ func (r *Runsc) State(context context.Context, id string) (*runc.Container, erro return &c, nil } +// CreateOpts is a set of options to Runsc.Create(). type CreateOpts struct { runc.IO ConsoleSocket runc.ConsoleSocket @@ -197,6 +221,7 @@ func (r *Runsc) Wait(context context.Context, id string) (int, error) { return res.ExitStatus, nil } +// ExecOpts is a set of options to runsc.Exec(). type ExecOpts struct { runc.IO PidFile string @@ -301,6 +326,7 @@ func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts return Monitor.Wait(cmd, ec) } +// DeleteOpts is a set of options to runsc.Delete(). type DeleteOpts struct { Force bool } @@ -365,8 +391,16 @@ func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { }() var e runc.Event if err := json.NewDecoder(rd).Decode(&e); err != nil { + log.L.Debugf("Parsing events error: %v", err) return nil, err } + log.L.Debugf("Stats returned, type: %s, stats: %+v", e.Type, e.Stats) + if e.Type != "stats" { + return nil, fmt.Errorf(`unexpected event type %q, wanted "stats"`, e.Type) + } + if e.Stats == nil { + return nil, fmt.Errorf(`"runsc events -stat" succeeded but no stat was provided`) + } return e.Stats, nil } diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go index c514b3bc7..55f17d29e 100644 --- a/pkg/shim/runsc/utils.go +++ b/pkg/shim/runsc/utils.go @@ -36,9 +36,20 @@ func putBuf(b *bytes.Buffer) { bytesBufferPool.Put(b) } -// FormatLogPath parses runsc config, and fill in %ID% in the log path. -func FormatLogPath(id string, config map[string]string) { +// FormatRunscLogPath parses runsc config, and fill in %ID% in the log path. +func FormatRunscLogPath(id string, config map[string]string) { if path, ok := config["debug-log"]; ok { config["debug-log"] = strings.Replace(path, "%ID%", id, -1) } } + +// FormatShimLogPath creates the file path to the log file. It replaces %ID% +// in the path with the provided "id". It also uses a default log name if the +// path end with '/'. +func FormatShimLogPath(path string, id string) string { + if strings.HasSuffix(path, "/") { + // Default format: <path>/runsc-shim-<ID>.log + path += "runsc-shim-%ID%.log" + } + return strings.Replace(path, "%ID%", id, -1) +} diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go index dab3123d6..9fd7d978c 100644 --- a/pkg/shim/v1/proc/init.go +++ b/pkg/shim/v1/proc/init.go @@ -397,7 +397,7 @@ func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Pr } // exec returns a new exec'd process. -func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { +func (p *Init) exec(path string, r *ExecConfig) (process.Process, error) { // process exec request var spec specs.Process if err := json.Unmarshal(r.Spec.Value, &spec); err != nil { diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go index 9233ecc85..0065fc385 100644 --- a/pkg/shim/v1/proc/init_state.go +++ b/pkg/shim/v1/proc/init_state.go @@ -95,7 +95,7 @@ func (s *createdState) SetExited(status int) { } func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { - return s.p.exec(ctx, path, r) + return s.p.exec(path, r) } type runningState struct { @@ -137,7 +137,7 @@ func (s *runningState) SetExited(status int) { } func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { - return s.p.exec(ctx, path, r) + return s.p.exec(path, r) } type stoppedState struct { diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go index d462c3eef..e8315326d 100644 --- a/pkg/shim/v1/proc/process.go +++ b/pkg/shim/v1/proc/process.go @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package proc contains process-related utilities. package proc import ( diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go index 2b0df4663..fc182cf5e 100644 --- a/pkg/shim/v1/proc/types.go +++ b/pkg/shim/v1/proc/types.go @@ -40,7 +40,6 @@ type CreateConfig struct { Stdin string Stdout string Stderr string - Options *types.Any } // ExecConfig holds exec creation configuration. diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go index 716de2f59..7c2c409af 100644 --- a/pkg/shim/v1/proc/utils.go +++ b/pkg/shim/v1/proc/utils.go @@ -67,24 +67,6 @@ func getLastRuntimeError(r *runsc.Runsc) (string, error) { return errMsg, nil } -func copyFile(to, from string) error { - ff, err := os.Open(from) - if err != nil { - return err - } - defer ff.Close() - tt, err := os.Create(to) - if err != nil { - return err - } - defer tt.Close() - - p := bufPool.Get().(*[]byte) - defer bufPool.Put(p) - _, err = io.CopyBuffer(tt, ff, *p) - return err -} - func hasNoIO(r *CreateConfig) bool { return r.Stdin == "" && r.Stdout == "" && r.Stderr == "" } diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD index 05c595bc9..e5b6bf186 100644 --- a/pkg/shim/v1/shim/BUILD +++ b/pkg/shim/v1/shim/BUILD @@ -8,6 +8,7 @@ go_library( "api.go", "platform.go", "service.go", + "shim.go", ], visibility = [ "//pkg/shim:__subpackages__", diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go index 84a810cb2..80aa59b33 100644 --- a/pkg/shim/v1/shim/service.go +++ b/pkg/shim/v1/shim/service.go @@ -130,7 +130,6 @@ func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shi Stdin: r.Stdin, Stdout: r.Stdout, Stderr: r.Stderr, - Options: r.Options, } defer func() { if err != nil { @@ -150,7 +149,6 @@ func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shi } } process, err := newInit( - ctx, s.config.Path, s.config.WorkDir, s.config.RuntimeRoot, @@ -158,6 +156,7 @@ func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shi s.config.RunscConfig, s.platform, config, + r.Options, ) if err := process.Create(ctx, config); err != nil { return nil, errdefs.ToGRPC(err) @@ -533,14 +532,14 @@ func getTopic(ctx context.Context, e interface{}) string { return runtime.TaskUnknownTopic } -func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) { - var options runctypes.CreateOptions - if r.Options != nil { - v, err := typeurl.UnmarshalAny(r.Options) +func newInit(path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig, options *types.Any) (*proc.Init, error) { + var opts runctypes.CreateOptions + if options != nil { + v, err := typeurl.UnmarshalAny(options) if err != nil { return nil, err } - options = *v.(*runctypes.CreateOptions) + opts = *v.(*runctypes.CreateOptions) } spec, err := utils.ReadSpec(r.Bundle) @@ -551,7 +550,7 @@ func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, return nil, fmt.Errorf("update volume annotations: %w", err) } - runsc.FormatLogPath(r.ID, config) + runsc.FormatRunscLogPath(r.ID, config) rootfs := filepath.Join(path, "rootfs") runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config) p := proc.New(r.ID, runtime, stdio.Stdio{ @@ -564,8 +563,8 @@ func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, p.Platform = platform p.Rootfs = rootfs p.WorkDir = workDir - p.IoUID = int(options.IoUid) - p.IoGID = int(options.IoGid) + p.IoUID = int(opts.IoUid) + p.IoGID = int(opts.IoGid) p.Sandbox = utils.IsSandbox(spec) p.UserLog = utils.UserLogPath(spec) p.Monitor = reaper.Default diff --git a/pkg/shim/v1/shim/shim.go b/pkg/shim/v1/shim/shim.go new file mode 100644 index 000000000..1855a8769 --- /dev/null +++ b/pkg/shim/v1/shim/shim.go @@ -0,0 +1,17 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package shim contains the core containerd shim implementation. +package shim diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go index 07e346654..21e75d16d 100644 --- a/pkg/shim/v1/utils/utils.go +++ b/pkg/shim/v1/utils/utils.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package utils contains utility functions. package utils import ( diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD index 7e0a114a0..b0e8daa51 100644 --- a/pkg/shim/v2/BUILD +++ b/pkg/shim/v2/BUILD @@ -7,19 +7,22 @@ go_library( srcs = [ "api.go", "epoll.go", + "options.go", "service.go", "service_linux.go", + "state.go", ], visibility = ["//shim:__subpackages__"], deps = [ + "//pkg/cleanup", "//pkg/shim/runsc", "//pkg/shim/v1/proc", "//pkg/shim/v1/utils", - "//pkg/shim/v2/options", "//pkg/shim/v2/runtimeoptions", "//runsc/specutils", "@com_github_burntsushi_toml//:go_default_library", "@com_github_containerd_cgroups//:go_default_library", + "@com_github_containerd_cgroups//stats/v1:go_default_library", "@com_github_containerd_console//:go_default_library", "@com_github_containerd_containerd//api/events:go_default_library", "@com_github_containerd_containerd//api/types/task:go_default_library", @@ -38,6 +41,7 @@ go_library( "@com_github_containerd_fifo//:go_default_library", "@com_github_containerd_typeurl//:go_default_library", "@com_github_gogo_protobuf//types:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/shim/v2/options.go b/pkg/shim/v2/options.go new file mode 100644 index 000000000..9db33fd1f --- /dev/null +++ b/pkg/shim/v2/options.go @@ -0,0 +1,50 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +const optionsType = "io.containerd.runsc.v1.options" + +// options is runtime options for io.containerd.runsc.v1. +type options struct { + // ShimCgroup is the cgroup the shim should be in. + ShimCgroup string `toml:"shim_cgroup" json:"shimCgroup"` + + // IoUID is the I/O's pipes uid. + IoUID uint32 `toml:"io_uid" json:"ioUid"` + + // IoGID is the I/O's pipes gid. + IoGID uint32 `toml:"io_gid" json:"ioGid"` + + // BinaryName is the binary name of the runsc binary. + BinaryName string `toml:"binary_name" json:"binaryName"` + + // Root is the runsc root directory. + Root string `toml:"root" json:"root"` + + // LogLevel sets the logging level. Some of the possible values are: debug, + // info, warning. + // + // This configuration only applies when the shim is running as a service. + LogLevel string `toml:"log_level" json:"logLevel"` + + // LogPath is the path to log directory. %ID% tags inside the string are + // replaced with the container ID. + // + // This configuration only applies when the shim is running as a service. + LogPath string `toml:"log_path" json:"logPath"` + + // RunscConfig is a key/value map of all runsc flags. + RunscConfig map[string]string `toml:"runsc_config" json:"runscConfig"` +} diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go deleted file mode 100644 index de09f2f79..000000000 --- a/pkg/shim/v2/options/options.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package options - -const OptionType = "io.containerd.runsc.v1.options" - -// Options is runtime options for io.containerd.runsc.v1. -type Options struct { - // ShimCgroup is the cgroup the shim should be in. - ShimCgroup string `toml:"shim_cgroup"` - // IoUid is the I/O's pipes uid. - IoUid uint32 `toml:"io_uid"` - // IoUid is the I/O's pipes gid. - IoGid uint32 `toml:"io_gid"` - // BinaryName is the binary name of the runsc binary. - BinaryName string `toml:"binary_name"` - // Root is the runsc root directory. - Root string `toml:"root"` - // RunscConfig is a key/value map of all runsc flags. - RunscConfig map[string]string `toml:"runsc_config"` -} diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go index 1534152fc..6aaf5fab8 100644 --- a/pkg/shim/v2/service.go +++ b/pkg/shim/v2/service.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package v2 implements Containerd Shim v2 interface. package v2 import ( "context" "fmt" - "io/ioutil" + "io" "os" "os/exec" "path/filepath" @@ -27,6 +28,7 @@ import ( "github.com/BurntSushi/toml" "github.com/containerd/cgroups" + cgroupsstats "github.com/containerd/cgroups/stats/v1" "github.com/containerd/console" "github.com/containerd/containerd/api/events" "github.com/containerd/containerd/api/types/task" @@ -43,12 +45,13 @@ import ( "github.com/containerd/containerd/sys/reaper" "github.com/containerd/typeurl" "github.com/gogo/protobuf/types" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/shim/runsc" "gvisor.dev/gvisor/pkg/shim/v1/proc" "gvisor.dev/gvisor/pkg/shim/v1/utils" - "gvisor.dev/gvisor/pkg/shim/v2/options" "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions" "gvisor.dev/gvisor/runsc/specutils" ) @@ -65,55 +68,108 @@ var ( var _ = (taskAPI.TaskService)(&service{}) -// configFile is the default config file name. For containerd 1.2, -// we assume that a config.toml should exist in the runtime root. -const configFile = "config.toml" +const ( + // configFile is the default config file name. For containerd 1.2, + // we assume that a config.toml should exist in the runtime root. + configFile = "config.toml" + + // shimAddressPath is the relative path to a file that contains the address + // to the shim UDS. See service.shimAddress. + shimAddressPath = "address" +) // New returns a new shim service that can be used via GRPC. func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) { + log.L.Debugf("service.New, id: %s", id) + + var opts shim.Opts + if ctxOpts := ctx.Value(shim.OptsKey{}); ctxOpts != nil { + opts = ctxOpts.(shim.Opts) + } + ep, err := newOOMEpoller(publisher) if err != nil { return nil, err } go ep.run(ctx) s := &service{ - id: id, - context: ctx, - processes: make(map[string]process.Process), - events: make(chan interface{}, 128), - ec: proc.ExitCh, - oomPoller: ep, - cancel: cancel, - } - go s.processExits() - runsc.Monitor = reaper.Default + id: id, + processes: make(map[string]process.Process), + events: make(chan interface{}, 128), + ec: proc.ExitCh, + oomPoller: ep, + cancel: cancel, + genericOptions: opts, + } + go s.processExits(ctx) + runsc.Monitor = &runsc.LogMonitor{Next: reaper.Default} if err := s.initPlatform(); err != nil { cancel() return nil, fmt.Errorf("failed to initialized platform behavior: %w", err) } - go s.forward(publisher) + go s.forward(ctx, publisher) + + if address, err := shim.ReadAddress(shimAddressPath); err == nil { + s.shimAddress = address + } + return s, nil } -// service is the shim implementation of a remote shim over GRPC. +// service is the shim implementation of a remote shim over GRPC. It runs in 2 +// different modes: +// 1. Service: process runs for the life time of the container and receives +// calls described in shimapi.TaskService interface. +// 2. Tool: process is short lived and runs only to perform the requested +// operations and then exits. It implements the direct functions in +// shim.Shim interface. +// +// When the service is running, it saves a json file with state information so +// that commands sent to the tool can load the state and perform the operation. type service struct { mu sync.Mutex - context context.Context - task process.Process + // id is the container ID. + id string + + // bundle is a path provided by the caller on container creation. Store + // because it's needed in commands that don't receive bundle in the request. + bundle string + + // task is the main process that is running the container. + task *proc.Init + + // processes maps ExecId to processes running through exec. processes map[string]process.Process - events chan interface{} - platform stdio.Platform - opts options.Options - ec chan proc.Exit + + events chan interface{} + + // platform handles operations related to the console. + platform stdio.Platform + + // genericOptions are options that come from the shim interface and are common + // to all shims. + genericOptions shim.Opts + + // opts are configuration options specific for this shim. + opts options + + // ex gets notified whenever the container init process or an exec'd process + // exits from inside the sandbox. + ec chan proc.Exit + + // oomPoller monitors the sandbox's cgroup for OOM notifications. oomPoller *epoller - id string - bundle string + // cancel is a function that needs to be called before the shim stops. The + // function is provided by the caller to New(). cancel func() + + // shimAddress is the location of the UDS used to communicate to containerd. + shimAddress string } -func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) { +func (s *service) newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) { ns, err := namespaces.NamespaceRequired(ctx) if err != nil { return nil, err @@ -131,6 +187,9 @@ func newCommand(ctx context.Context, containerdBinary, containerdAddress string) "-address", containerdAddress, "-publish-binary", containerdBinary, } + if s.genericOptions.Debug { + args = append(args, "-debug") + } cmd := exec.Command(self, args...) cmd.Dir = cwd cmd.Env = append(os.Environ(), "GOMAXPROCS=2") @@ -141,50 +200,78 @@ func newCommand(ctx context.Context, containerdBinary, containerdAddress string) } func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) { - cmd, err := newCommand(ctx, containerdBinary, containerdAddress) + log.L.Debugf("StartShim, id: %s, binary: %q, address: %q", id, containerdBinary, containerdAddress) + + cmd, err := s.newCommand(ctx, containerdBinary, containerdAddress) if err != nil { return "", err } - address, err := shim.SocketAddress(ctx, id) + address, err := shim.SocketAddress(ctx, containerdAddress, id) if err != nil { return "", err } socket, err := shim.NewSocket(address) if err != nil { - return "", err + // The only time where this would happen is if there is a bug and the socket + // was not cleaned up in the cleanup method of the shim or we are using the + // grouping functionality where the new process should be run with the same + // shim as an existing container. + if !shim.SocketEaddrinuse(err) { + return "", fmt.Errorf("create new shim socket: %w", err) + } + if shim.CanConnect(address) { + if err := shim.WriteAddress(shimAddressPath, address); err != nil { + return "", fmt.Errorf("write existing socket for shim: %w", err) + } + return address, nil + } + if err := shim.RemoveSocket(address); err != nil { + return "", fmt.Errorf("remove pre-existing socket: %w", err) + } + if socket, err = shim.NewSocket(address); err != nil { + return "", fmt.Errorf("try create new shim socket 2x: %w", err) + } } - defer socket.Close() + cu := cleanup.Make(func() { + socket.Close() + _ = shim.RemoveSocket(address) + }) + defer cu.Clean() + f, err := socket.File() if err != nil { return "", err } - defer f.Close() cmd.ExtraFiles = append(cmd.ExtraFiles, f) + log.L.Debugf("Executing: %q %s", cmd.Path, cmd.Args) if err := cmd.Start(); err != nil { + f.Close() return "", err } - defer func() { - if err != nil { - cmd.Process.Kill() - } - }() + cu.Add(func() { cmd.Process.Kill() }) + // make sure to wait after start go cmd.Wait() if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil { return "", err } - if err := shim.WriteAddress("address", address); err != nil { + if err := shim.WriteAddress(shimAddressPath, address); err != nil { return "", err } if err := shim.SetScore(cmd.Process.Pid); err != nil { return "", fmt.Errorf("failed to set OOM Score on shim: %w", err) } + cu.Release() return address, nil } +// Cleanup is called from another process (need to reload state) to stop the +// container and undo all operations done in Create(). func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) { + log.L.Debugf("Cleanup") + path, err := os.Getwd() if err != nil { return nil, err @@ -193,18 +280,19 @@ func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) if err != nil { return nil, err } - runtime, err := s.readRuntime(path) - if err != nil { + var st state + if err := st.load(path); err != nil { return nil, err } - r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + r := proc.NewRunsc(s.opts.Root, path, ns, st.Options.BinaryName, nil) + if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{ Force: true, }); err != nil { - log.L.Printf("failed to remove runc container: %v", err) + log.L.Infof("failed to remove runc container: %v", err) } - if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil { - log.L.Printf("failed to cleanup rootfs mount: %v", err) + if err := mount.UnmountAll(st.Rootfs, 0); err != nil { + log.L.Infof("failed to cleanup rootfs mount: %v", err) } return &taskAPI.DeleteResponse{ ExitedAt: time.Now(), @@ -212,31 +300,24 @@ func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) }, nil } -func (s *service) readRuntime(path string) (string, error) { - data, err := ioutil.ReadFile(filepath.Join(path, "runtime")) - if err != nil { - return "", err - } - return string(data), nil -} - -func (s *service) writeRuntime(path, runtime string) error { - return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600) -} - // Create creates a new initial process and container with the underlying OCI // runtime. -func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) { +func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (*taskAPI.CreateTaskResponse, error) { + log.L.Debugf("Create, id: %s, bundle: %q", r.ID, r.Bundle) + s.mu.Lock() defer s.mu.Unlock() + // Save the main task id and bundle to the shim for additional requests. + s.id = r.ID + s.bundle = r.Bundle + ns, err := namespaces.NamespaceRequired(ctx) if err != nil { return nil, fmt.Errorf("create namespace: %w", err) } // Read from root for now. - var opts options.Options if r.Options != nil { v, err := typeurl.UnmarshalAny(r.Options) if err != nil { @@ -245,16 +326,16 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * var path string switch o := v.(type) { case *runctypes.CreateOptions: // containerd 1.2.x - opts.IoUid = o.IoUid - opts.IoGid = o.IoGid - opts.ShimCgroup = o.ShimCgroup + s.opts.IoUID = o.IoUid + s.opts.IoGID = o.IoGid + s.opts.ShimCgroup = o.ShimCgroup case *runctypes.RuncOptions: // containerd 1.2.x root := proc.RunscRoot if o.RuntimeRoot != "" { root = o.RuntimeRoot } - opts.BinaryName = o.Runtime + s.opts.BinaryName = o.Runtime path = filepath.Join(root, configFile) if _, err := os.Stat(path); err != nil { @@ -268,7 +349,7 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * if o.ConfigPath == "" { break } - if o.TypeUrl != options.OptionType { + if o.TypeUrl != optionsType { return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl) } path = o.ConfigPath @@ -276,12 +357,61 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl) } if path != "" { - if _, err = toml.DecodeFile(path, &opts); err != nil { + if _, err = toml.DecodeFile(path, &s.opts); err != nil { return nil, fmt.Errorf("decode config file %q: %w", path, err) } } } + if len(s.opts.LogLevel) != 0 { + lvl, err := logrus.ParseLevel(s.opts.LogLevel) + if err != nil { + return nil, err + } + logrus.SetLevel(lvl) + } + if len(s.opts.LogPath) != 0 { + logPath := runsc.FormatShimLogPath(s.opts.LogPath, s.id) + if err := os.MkdirAll(filepath.Dir(logPath), 0777); err != nil { + return nil, fmt.Errorf("failed to create log dir: %w", err) + } + logFile, err := os.Create(logPath) + if err != nil { + return nil, fmt.Errorf("failed to create log file: %w", err) + } + log.L.Debugf("Starting mirror log at %q", logPath) + std := logrus.StandardLogger() + std.SetOutput(io.MultiWriter(std.Out, logFile)) + + log.L.Debugf("Create shim") + log.L.Debugf("***************************") + log.L.Debugf("Args: %s", os.Args) + log.L.Debugf("PID: %d", os.Getpid()) + log.L.Debugf("ID: %s", s.id) + log.L.Debugf("Options: %+v", s.opts) + log.L.Debugf("Bundle: %s", r.Bundle) + log.L.Debugf("Terminal: %t", r.Terminal) + log.L.Debugf("stdin: %s", r.Stdin) + log.L.Debugf("stdout: %s", r.Stdout) + log.L.Debugf("stderr: %s", r.Stderr) + log.L.Debugf("***************************") + } + + // Save state before any action is taken to ensure Cleanup() will have all + // the information it needs to undo the operations. + st := state{ + Rootfs: filepath.Join(r.Bundle, "rootfs"), + Options: s.opts, + } + if err := st.save(r.Bundle); err != nil { + return nil, err + } + + if err := os.Mkdir(st.Rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + // Convert from types.Mount to proc.Mount. var mounts []proc.Mount for _, m := range r.Rootfs { mounts = append(mounts, proc.Mount{ @@ -292,62 +422,41 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * }) } - rootfs := filepath.Join(r.Bundle, "rootfs") - if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { - return nil, err + // Cleans up all mounts in case of failure. + cu := cleanup.Make(func() { + if err := mount.UnmountAll(st.Rootfs, 0); err != nil { + log.L.Infof("failed to cleanup rootfs mount: %v", err) + } + }) + defer cu.Clean() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(st.Rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } } config := &proc.CreateConfig{ ID: r.ID, Bundle: r.Bundle, - Runtime: opts.BinaryName, + Runtime: s.opts.BinaryName, Rootfs: mounts, Terminal: r.Terminal, Stdin: r.Stdin, Stdout: r.Stdout, Stderr: r.Stderr, - Options: r.Options, - } - if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil { - return nil, err } - defer func() { - if err != nil { - if err := mount.UnmountAll(rootfs, 0); err != nil { - log.L.Printf("failed to cleanup rootfs mount: %v", err) - } - } - }() - for _, rm := range mounts { - m := &mount.Mount{ - Type: rm.Type, - Source: rm.Source, - Options: rm.Options, - } - if err := m.Mount(rootfs); err != nil { - return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) - } - } - process, err := newInit( - ctx, - r.Bundle, - filepath.Join(r.Bundle, "work"), - ns, - s.platform, - config, - &opts, - rootfs, - ) + process, err := newInit(r.Bundle, filepath.Join(r.Bundle, "work"), ns, s.platform, config, &s.opts, st.Rootfs) if err != nil { return nil, errdefs.ToGRPC(err) } if err := process.Create(ctx, config); err != nil { return nil, errdefs.ToGRPC(err) } - // Save the main task id and bundle to the shim for additional - // requests. - s.id = r.ID - s.bundle = r.Bundle // Set up OOM notification on the sandbox's cgroup. This is done on // sandbox create since the sandbox process will be created here. @@ -361,16 +470,19 @@ func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ * return nil, fmt.Errorf("add cg to OOM monitor: %w", err) } } + + // Success + cu.Release() s.task = process - s.opts = opts return &taskAPI.CreateTaskResponse{ Pid: uint32(process.Pid()), }, nil - } // Start starts a process. func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) { + log.L.Debugf("Start, id: %s, execID: %s", r.ID, r.ExecID) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -387,6 +499,8 @@ func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI. // Delete deletes the initial process and container. func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) { + log.L.Debugf("Delete, id: %s, execID: %s", r.ID, r.ExecID) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -397,13 +511,11 @@ func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAP if err := p.Delete(ctx); err != nil { return nil, err } - isTask := r.ExecID == "" - if !isTask { + if len(r.ExecID) != 0 { s.mu.Lock() delete(s.processes, r.ExecID) s.mu.Unlock() - } - if isTask && s.platform != nil { + } else if s.platform != nil { s.platform.Close() } return &taskAPI.DeleteResponse{ @@ -415,17 +527,18 @@ func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAP // Exec spawns an additional process inside the container. func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) { + log.L.Debugf("Exec, id: %s, execID: %s", r.ID, r.ExecID) + s.mu.Lock() p := s.processes[r.ExecID] s.mu.Unlock() if p != nil { return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID) } - p = s.task - if p == nil { + if s.task == nil { return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } - process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{ + process, err := s.task.Exec(ctx, s.bundle, &proc.ExecConfig{ ID: r.ExecID, Terminal: r.Terminal, Stdin: r.Stdin, @@ -444,6 +557,8 @@ func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*typ // ResizePty resizes the terminal of a process. func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) { + log.L.Debugf("ResizePty, id: %s, execID: %s, dimension: %dx%d", r.ID, r.ExecID, r.Height, r.Width) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -460,6 +575,8 @@ func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (* // State returns runtime state information for a process. func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) { + log.L.Debugf("State, id: %s, execID: %s", r.ID, r.ExecID) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -494,16 +611,20 @@ func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI. // Pause the container. func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) { + log.L.Debugf("Pause, id: %s", r.ID) return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) } // Resume the container. func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) { + log.L.Debugf("Resume, id: %s", r.ID) return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) } // Kill a process with the provided signal. func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) { + log.L.Debugf("Kill, id: %s, execID: %s, signal: %d, all: %t", r.ID, r.ExecID, r.Signal, r.All) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -519,6 +640,8 @@ func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empt // Pids returns all pids inside the container. func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) { + log.L.Debugf("Pids, id: %s", r.ID) + pids, err := s.getContainerPids(ctx, r.ID) if err != nil { return nil, errdefs.ToGRPC(err) @@ -550,6 +673,8 @@ func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.Pi // CloseIO closes the I/O context of a process. func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) { + log.L.Debugf("CloseIO, id: %s, execID: %s, stdin: %t", r.ID, r.ExecID, r.Stdin) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -564,11 +689,14 @@ func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*type // Checkpoint checkpoints the container. func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) { + log.L.Debugf("Checkpoint, id: %s", r.ID) return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) } // Connect returns shim information such as the shim's pid. func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) { + log.L.Debugf("Connect, id: %s", r.ID) + var pid int if s.task != nil { pid = s.task.Pid() @@ -580,27 +708,24 @@ func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*task } func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) { + log.L.Debugf("Shutdown, id: %s", r.ID) s.cancel() + if s.shimAddress != "" { + _ = shim.RemoveSocket(s.shimAddress) + } os.Exit(0) - return empty, nil + panic("Should not get here") } func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) { - path, err := os.Getwd() - if err != nil { - return nil, err - } - ns, err := namespaces.NamespaceRequired(ctx) - if err != nil { - return nil, err - } - runtime, err := s.readRuntime(path) - if err != nil { - return nil, err + log.L.Debugf("Stats, id: %s", r.ID) + if s.task == nil { + log.L.Debugf("Stats error, id: %s: container not created", r.ID) + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") } - rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) - stats, err := rs.Stats(ctx, s.id) + stats, err := s.task.Runtime().Stats(ctx, s.id) if err != nil { + log.L.Debugf("Stats error, id: %s: %v", r.ID, err) return nil, err } @@ -611,55 +736,58 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI. // as runc. // // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81 - data, err := typeurl.MarshalAny(&cgroups.Metrics{ - CPU: &cgroups.CPUStat{ - Usage: &cgroups.CPUUsage{ + metrics := &cgroupsstats.Metrics{ + CPU: &cgroupsstats.CPUStat{ + Usage: &cgroupsstats.CPUUsage{ Total: stats.Cpu.Usage.Total, Kernel: stats.Cpu.Usage.Kernel, User: stats.Cpu.Usage.User, PerCPU: stats.Cpu.Usage.Percpu, }, - Throttling: &cgroups.Throttle{ + Throttling: &cgroupsstats.Throttle{ Periods: stats.Cpu.Throttling.Periods, ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods, ThrottledTime: stats.Cpu.Throttling.ThrottledTime, }, }, - Memory: &cgroups.MemoryStat{ + Memory: &cgroupsstats.MemoryStat{ Cache: stats.Memory.Cache, - Usage: &cgroups.MemoryEntry{ + Usage: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Usage.Limit, Usage: stats.Memory.Usage.Usage, Max: stats.Memory.Usage.Max, Failcnt: stats.Memory.Usage.Failcnt, }, - Swap: &cgroups.MemoryEntry{ + Swap: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Swap.Limit, Usage: stats.Memory.Swap.Usage, Max: stats.Memory.Swap.Max, Failcnt: stats.Memory.Swap.Failcnt, }, - Kernel: &cgroups.MemoryEntry{ + Kernel: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Kernel.Limit, Usage: stats.Memory.Kernel.Usage, Max: stats.Memory.Kernel.Max, Failcnt: stats.Memory.Kernel.Failcnt, }, - KernelTCP: &cgroups.MemoryEntry{ + KernelTCP: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.KernelTCP.Limit, Usage: stats.Memory.KernelTCP.Usage, Max: stats.Memory.KernelTCP.Max, Failcnt: stats.Memory.KernelTCP.Failcnt, }, }, - Pids: &cgroups.PidsStat{ + Pids: &cgroupsstats.PidsStat{ Current: stats.Pids.Current, Limit: stats.Pids.Limit, }, - }) + } + data, err := typeurl.MarshalAny(metrics) if err != nil { + log.L.Debugf("Stats error, id: %s: %v", r.ID, err) return nil, err } + log.L.Debugf("Stats success, id: %s: %+v", r.ID, data) return &taskAPI.StatsResponse{ Stats: data, }, nil @@ -672,6 +800,8 @@ func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*ty // Wait waits for a process to exit. func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) { + log.L.Debugf("Wait, id: %s, execID: %s", r.ID, r.ExecID) + p, err := s.getProcess(r.ExecID) if err != nil { return nil, err @@ -687,21 +817,22 @@ func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.Wa }, nil } -func (s *service) processExits() { +func (s *service) processExits(ctx context.Context) { for e := range s.ec { - s.checkProcesses(e) + s.checkProcesses(ctx, e) } } -func (s *service) checkProcesses(e proc.Exit) { +func (s *service) checkProcesses(ctx context.Context, e proc.Exit) { // TODO(random-liu): Add `shouldKillAll` logic if container pid // namespace is supported. for _, p := range s.allProcesses() { if p.ID() == e.ID { if ip, ok := p.(*proc.Init); ok { // Ensure all children are killed. - if err := ip.KillAll(s.context); err != nil { - log.G(s.context).WithError(err).WithField("id", ip.ID()). + log.L.Debugf("Container init process exited, killing all container processes") + if err := ip.KillAll(ctx); err != nil { + log.G(ctx).WithError(err).WithField("id", ip.ID()). Error("failed to kill init's children") } } @@ -737,7 +868,7 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er if p == nil { return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition) } - ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + ps, err := p.Runtime().Ps(ctx, id) if err != nil { return nil, err } @@ -748,11 +879,9 @@ func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, er return pids, nil } -func (s *service) forward(publisher shim.Publisher) { +func (s *service) forward(ctx context.Context, publisher shim.Publisher) { for e := range s.events { - ctx, cancel := context.WithTimeout(s.context, 5*time.Second) err := publisher.Publish(ctx, getTopic(e), e) - cancel() if err != nil { // Should not happen. panic(fmt.Errorf("post event: %w", err)) @@ -790,12 +919,12 @@ func getTopic(e interface{}) string { case *events.TaskExecStarted: return runtime.TaskExecStartedEventTopic default: - log.L.Printf("no topic for type %#v", e) + log.L.Infof("no topic for type %#v", e) } return runtime.TaskUnknownTopic } -func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) { +func newInit(path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options, rootfs string) (*proc.Init, error) { spec, err := utils.ReadSpec(r.Bundle) if err != nil { return nil, fmt.Errorf("read oci spec: %w", err) @@ -803,7 +932,7 @@ func newInit(ctx context.Context, path, workDir, namespace string, platform stdi if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { return nil, fmt.Errorf("update volume annotations: %w", err) } - runsc.FormatLogPath(r.ID, options.RunscConfig) + runsc.FormatRunscLogPath(r.ID, options.RunscConfig) runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig) p := proc.New(r.ID, runtime, stdio.Stdio{ Stdin: r.Stdin, @@ -815,8 +944,8 @@ func newInit(ctx context.Context, path, workDir, namespace string, platform stdi p.Platform = platform p.Rootfs = rootfs p.WorkDir = workDir - p.IoUID = int(options.IoUid) - p.IoGID = int(options.IoGid) + p.IoUID = int(options.IoUID) + p.IoGID = int(options.IoGID) p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox p.UserLog = utils.UserLogPath(spec) p.Monitor = reaper.Default diff --git a/pkg/shim/v2/state.go b/pkg/shim/v2/state.go new file mode 100644 index 000000000..1f4be33d3 --- /dev/null +++ b/pkg/shim/v2/state.go @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "encoding/json" + "io/ioutil" + "path/filepath" +) + +const filename = "state.json" + +// state holds information needed between shim invocations. +type state struct { + // Rootfs is the full path to the location rootfs was mounted. + Rootfs string `json:"rootfs"` + + // Options is the configuration loaded from config.toml. + Options options `json:"options"` +} + +func (s state) load(path string) error { + data, err := ioutil.ReadFile(filepath.Join(path, filename)) + if err != nil { + return err + } + return json.Unmarshal(data, &s) +} + +func (s state) save(path string) error { + data, err := json.Marshal(&s) + if err != nil { + return err + } + return ioutil.WriteFile(filepath.Join(path, filename), data, 0644) +} diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index ae0fe1522..48bcdd62b 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -5,10 +5,6 @@ package(licenses = ["notice"]) go_library( name = "sleep", srcs = [ - "commit_amd64.s", - "commit_arm64.s", - "commit_asm.go", - "commit_noasm.go", "sleep_unsafe.go", ], visibility = ["//:sandbox"], diff --git a/pkg/sleep/commit_amd64.s b/pkg/sleep/commit_amd64.s deleted file mode 100644 index bc4ac2c3c..000000000 --- a/pkg/sleep/commit_amd64.s +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -#define preparingG 1 - -// See commit_noasm.go for a description of commitSleep. -// -// func commitSleep(g uintptr, waitingG *uintptr) bool -TEXT ·commitSleep(SB),NOSPLIT,$0-24 - MOVQ waitingG+8(FP), CX - MOVQ g+0(FP), DX - - // Store the G in waitingG if it's still preparingG. If it's anything - // else it means a waker has aborted the sleep. - MOVQ $preparingG, AX - LOCK - CMPXCHGQ DX, 0(CX) - - SETEQ AX - MOVB AX, ret+16(FP) - - RET diff --git a/pkg/sleep/commit_arm64.s b/pkg/sleep/commit_arm64.s deleted file mode 100644 index d0ef15b20..000000000 --- a/pkg/sleep/commit_arm64.s +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" - -#define preparingG 1 - -// See commit_noasm.go for a description of commitSleep. -// -// func commitSleep(g uintptr, waitingG *uintptr) bool -TEXT ·commitSleep(SB),NOSPLIT,$0-24 - MOVD waitingG+8(FP), R0 - MOVD $preparingG, R1 - MOVD G+0(FP), R2 - - // Store the G in waitingG if it's still preparingG. If it's anything - // else it means a waker has aborted the sleep. -again: - LDAXR (R0), R3 - CMP R1, R3 - BNE ok - STLXR R2, (R0), R3 - CBNZ R3, again -ok: - CSET EQ, R0 - MOVB R0, ret+16(FP) - RET diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go deleted file mode 100644 index f59061f37..000000000 --- a/pkg/sleep/commit_noasm.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !race -// +build !amd64,!arm64 - -package sleep - -import "sync/atomic" - -// commitSleep signals to wakers that the given g is now sleeping. Wakers can -// then fetch it and wake it. -// -// The commit may fail if wakers have been asserted after our last check, in -// which case they will have set s.waitingG to zero. -// -// It is written in assembly because it is called from g0, so it doesn't have -// a race context. -func commitSleep(g uintptr, waitingG *uintptr) bool { - // Try to store the G so that wakers know who to wake. - return atomic.CompareAndSwapUintptr(waitingG, preparingG, g) -} diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 19bce2afb..c44206b1e 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.11 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - // Package sleep allows goroutines to efficiently sleep on multiple sources of // notifications (wakers). It offers O(1) complexity, which is different from // multi-channel selects which have O(n) complexity (where n is the number of @@ -91,12 +86,6 @@ var ( assertedSleeper Sleeper ) -//go:linkname gopark runtime.gopark -func gopark(unlockf func(uintptr, *uintptr) bool, wg *uintptr, reason uint8, traceEv byte, traceskip int) - -//go:linkname goready runtime.goready -func goready(g uintptr, traceskip int) - // Sleeper allows a goroutine to sleep and receive wake up notifications from // Wakers in an efficient way. // @@ -189,7 +178,7 @@ func (s *Sleeper) nextWaker(block bool) *Waker { // See:runtime2.go in the go runtime package for // the values to pass as the waitReason here. const waitReasonSelect = 9 - gopark(commitSleep, &s.waitingG, waitReasonSelect, traceEvGoBlockSelect, 0) + sync.Gopark(commitSleep, unsafe.Pointer(&s.waitingG), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0) } // Pull the shared list out and reverse it in the local @@ -212,6 +201,18 @@ func (s *Sleeper) nextWaker(block bool) *Waker { return w } +// commitSleep signals to wakers that the given g is now sleeping. Wakers can +// then fetch it and wake it. +// +// The commit may fail if wakers have been asserted after our last check, in +// which case they will have set s.waitingG to zero. +// +//go:norace +//go:nosplit +func commitSleep(g uintptr, waitingG unsafe.Pointer) bool { + return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(waitingG), preparingG, g) +} + // Fetch fetches the next wake-up notification. If a notification is immediately // available, it is returned right away. Otherwise, the behavior depends on the // value of 'block': if true, the current goroutine blocks until a notification @@ -311,7 +312,7 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) { case 0, preparingG: default: // We managed to get a G. Wake it up. - goready(g, 0) + sync.Goready(g, 0) } } diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 089b3bbef..92c51879b 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "pending_list", - out = "pending_list.go", - package = "state", - prefix = "pending", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*objectEncodeState", - "ElementMapper": "pendingMapper", - "Linker": "*pendingEntry", - }, -) - -go_template_instance( name = "deferred_list", out = "deferred_list.go", package = "state", @@ -83,7 +70,6 @@ go_library( "deferred_list.go", "encode.go", "encode_unsafe.go", - "pending_list.go", "state.go", "state_norace.go", "state_race.go", diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 89467ca8e..e519ddeca 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -21,6 +21,7 @@ import ( "math" "reflect" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/state/wire" ) @@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c // For the purposes of this function, a child object is either a field within a // struct or an array element, with one such indirection per element in // path. The returned value may be an unexported field, so it may not be -// directly assignable. See unsafePointerTo. +// directly assignable. See decode_unsafe.go. func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { // See wire.Ref.Dots. The path here is specified in reverse order. for i := len(path) - 1; i >= 0; i-- { @@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // Normal assignment: authoritative only if no dots. v := ds.register(x, obj.Type().Elem()) - if v.IsValid() { - obj.Set(unsafePointerTo(v)) - } + obj.Set(reflectValueRWAddr(v)) case wire.Bool: obj.SetBool(bool(x)) case wire.Int: @@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // contents will still be filled in later on. typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. v := ds.register(&x.Ref, typ) - obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity))) case *wire.Array: ds.decodeArray(ods, obj, x) case *wire.Struct: @@ -592,7 +591,7 @@ func (ds *decodeState) Load(obj reflect.Value) { ds.pending.PushBack(rootOds) // Read the number of objects. - lastID, object, err := ReadHeader(ds.r) + numObjects, object, err := ReadHeader(ds.r) if err != nil { Failf("header error: %w", err) } @@ -604,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) { var ( encoded wire.Object ods *objectDecodeState - id = objectID(1) + id objectID tid = typeID(1) ) if err := safely(func() { // Decode all objects in the stream. // - // Note that the structure of this decoding loop should match - // the raw decoding loop in printer.go. - for id <= objectID(lastID) { - // Unmarshal the object. + // Note that the structure of this decoding loop should match the raw + // decoding loop in state/pretty/pretty.printer.printStream(). + for i := uint64(0); i < numObjects; { + // Unmarshal either a type object or object ID. encoded = wire.Load(ds.r) - - // Is this a type object? Handle inline. - if wt, ok := encoded.(*wire.Type); ok { - ds.types.Register(wt) + switch we := encoded.(type) { + case *wire.Type: + ds.types.Register(we) tid++ encoded = nil continue + case wire.Uint: + id = objectID(we) + i++ + // Unmarshal and resolve the actual object. + encoded = wire.Load(ds.r) + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we deferred + // decoding until interest is registered. + ds.deferred[id] = encoded + } + // For error handling. + ods = nil + encoded = nil + default: + Failf("wanted type or object ID, got %#v", encoded) } - - // Actually resolve the object. - ods = ds.lookup(id) - if ods != nil { - // Decode the object. - ds.decodeObject(ods, ods.obj, encoded) - } else { - // If an object hasn't had interest registered - // previously or isn't yet valid, we deferred - // decoding until interest is registered. - ds.deferred[id] = encoded - } - - // For error handling. - ods = nil - encoded = nil - id++ } }); err != nil { // Include as much information as we can, taking into account @@ -647,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) { if ods != nil { Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) } else if encoded != nil { - Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) + Failf("error decoding from %#v: %w", encoded, err) } else { Failf("general decoding error: %w", err) } } // Check if we have any deferred objects. + numDeferred := 0 for id, encoded := range ds.deferred { - // Shoud never happen, the graph was bogus. - Failf("still have deferred objects: one is ID %d, %#v", id, encoded) + numDeferred++ + if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 { + typ := ds.types.LookupType(typeID(s.TypeID)) + log.Warningf("unused deferred object: ID %d, type %v", id, typ) + } else { + log.Warningf("unused deferred object: ID %d, %#v", id, encoded) + } + } + if numDeferred != 0 { + Failf("still had %d deferred objects", numDeferred) } // Scan and fire all callbacks. We iterate over the list of incomplete diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go index d048f61a1..f1208e2a2 100644 --- a/pkg/state/decode_unsafe.go +++ b/pkg/state/decode_unsafe.go @@ -15,13 +15,62 @@ package state import ( + "fmt" "reflect" + "runtime" "unsafe" ) -// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on -// values representing unexported fields. This bypasses visibility, but not -// type safety. -func unsafePointerTo(obj reflect.Value) reflect.Value { +// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned +// reflect.Value is usable in assignments even if obj was obtained by the use +// of unexported struct fields. +// +// Preconditions: obj.CanAddr(). +func reflectValueRWAddr(obj reflect.Value) reflect.Value { return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr())) } + +// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the +// returned reflect.Value is usable in assignments even if obj was obtained by +// the use of unexported struct fields. +// +// Preconditions: +// * arr.Kind() == reflect.Array. +// * i, j, k >= 0. +// * i <= j <= k <= arr.Len(). +func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value { + if arr.Kind() != reflect.Array { + panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array)) + } + if i < 0 || j < 0 || k < 0 { + panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k)) + } + if i > j { + panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j)) + } + if j > k { + panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k)) + } + if k > arr.Len() { + panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len())) + } + + sliceTyp := reflect.SliceOf(arr.Type().Elem()) + if i == arr.Len() { + // By precondition, i == j == k == arr.Len(). + return reflect.MakeSlice(sliceTyp, 0, 0) + } + slh := reflect.SliceHeader{ + // reflect.Value.CanAddr() == false for arrays, so we need to get the + // address from the first element of the array. + Data: arr.Index(i).UnsafeAddr(), + Len: j - i, + Cap: k - i, + } + slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem() + // Before slobj is constructed, arr holds the only pointer-typed pointer to + // the array since reflect.SliceHeader.Data is a uintptr, so arr must be + // kept alive. + runtime.KeepAlive(arr) + return slobj +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 92fcad4e9..560e7c2a3 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -17,13 +17,14 @@ package state import ( "context" "reflect" + "sort" "gvisor.dev/gvisor/pkg/state/wire" ) // objectEncodeState the type and identity of an object occupying a memory // address range. This is the value type for addrSet, and the intrusive entry -// for the pending and deferred lists. +// for the deferred list. type objectEncodeState struct { // id is the assigned ID for this object. id objectID @@ -47,7 +48,6 @@ type objectEncodeState struct { // references may be updated directly and automatically. refs []*wire.Ref - pendingEntry deferredEntry } @@ -93,9 +93,15 @@ type encodeState struct { // serialized. pendingTypes []wire.Type - // pending is the list of objects to be serialized. Serialization does + // pending maps object IDs to objects to be serialized. Serialization does // not actually occur until the full object graph is computed. - pending pendingList + pending map[objectID]*objectEncodeState + + // encodedStructs maps reflect.Values representing structs to previous + // encodings of those structs. This is necessary to avoid duplicate calls + // to SaverLoader.StateSave() that may result in multiple calls to + // Sink.SaveValue() for a given field, resulting in object duplication. + encodedStructs map[reflect.Value]*wire.Struct // stats tracks time data. stats Stats @@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { // depending on this value knows there's nothing there. return } - if seg, _ := es.values.Find(addr); seg.Ok() { + seg, gap := es.values.Find(addr) + if seg.Ok() { // Ensure the map types match. existing := seg.Value() if existing.obj.Type() != obj.Type() { @@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { } // Record the map. + r := addrRange{addr, addr + 1} oes := &objectEncodeState{ id: es.nextID(), obj: obj, how: encodeMapAsValue, } - es.values.Add(addrRange{addr, addr + 1}, oes) - es.pending.PushBack(oes) + // Use Insert instead of InsertWithoutMergingUnchecked when race + // detection is enabled to get additional sanity-checking from Merge. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes es.deferred.PushBack(oes) // See above: no ref recording. @@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { obj: obj, } es.zeroValues[typ] = oes - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) } @@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { size = 1 // See above. } - // Calculate the container. end := addr + size r := addrRange{addr, end} - if seg, _ := es.values.Find(addr); seg.Ok() { + seg := es.values.LowerBoundSegment(addr) + var ( + oes *objectEncodeState + gap addrGapIterator + ) + + // Does at least one previously-registered object overlap this one? + if seg.Ok() && seg.Start() < end { existing := seg.Value() - switch { - case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type(): - // The object is a perfect match. Happy path. Avoid the - // traversal and just return directly. We don't need to - // encode the type information or any dots here. + + if seg.Range() == r && typ == existing.obj.Type() { + // This exact object is already registered. Avoid the traversal and + // just return directly. We don't need to encode the type + // information or any dots here. ref.Root = wire.Uint(existing.id) existing.refs = append(existing.refs, ref) return + } - case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end): - // The previously registered object is larger than - // this, no need to update. But we expect some - // traversal below. + if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) { + // This object is contained within a previously-registered object. + // Perform traversal from the container to the new object. + ref.Root = wire.Uint(existing.id) + ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr) + ref.Type = es.findType(existing.obj.Type()) + existing.refs = append(existing.refs, ref) + return + } - case seg.Start() == addr && seg.End() == end: - if !isSameSizeParent(obj, existing.obj.Type()) { - break // Needs traversal. + // This object contains one or more previously-registered objects. + // Remove them and update existing references to use the new one. + oes := &objectEncodeState{ + // Reuse the root ID of the first contained element. + id: existing.id, + obj: obj, + } + type elementEncodeState struct { + addr uintptr + typ reflect.Type + refs []*wire.Ref + } + var ( + elems []elementEncodeState + gap addrGapIterator + ) + for { + // Each contained object should be completely contained within + // this one. + if raceEnabled && !r.IsSupersetOf(seg.Range()) { + Failf("containing object %#v does not contain existing object %#v", obj, existing.obj) } - fallthrough // Needs update. - - case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end): - // Update the object and redo the encoding. - old := existing.obj - existing.obj = obj + elems = append(elems, elementEncodeState{ + addr: seg.Start(), + typ: existing.obj.Type(), + refs: existing.refs, + }) + delete(es.pending, existing.id) es.deferred.Remove(existing) - es.deferred.PushBack(existing) - - // The previously registered object is superseded by - // this new object. We are guaranteed to not have any - // mergeable neighbours in this segment set. - if !raceEnabled { - seg.SetRangeUnchecked(r) - } else { - // Add extra paranoid. This will be statically - // removed at compile time unless a race build. - es.values.Remove(seg) - es.values.Add(r, existing) - seg = es.values.LowerBoundSegment(addr) + gap = es.values.Remove(seg) + seg = gap.NextSegment() + if !seg.Ok() || seg.Start() >= end { + break } - - // Compute the traversal required & update references. - dots := traverse(obj.Type(), old.Type(), addr, seg.Start()) - wt := es.findType(obj.Type()) - for _, ref := range existing.refs { + existing = seg.Value() + } + wt := es.findType(typ) + for _, elem := range elems { + dots := traverse(typ, elem.typ, addr, elem.addr) + for _, ref := range elem.refs { + ref.Root = wire.Uint(oes.id) ref.Dots = append(ref.Dots, dots...) ref.Type = wt } - default: - // There is a non-sensical overlap. - Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj) + oes.refs = append(oes.refs, elem.refs...) } - - // Compute the new reference, record and return it. - ref.Root = wire.Uint(existing.id) - ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr) - ref.Type = es.findType(obj.Type()) - existing.refs = append(existing.refs, ref) + // Finally register the new containing object. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes + es.deferred.PushBack(oes) + ref.Root = wire.Uint(oes.id) + oes.refs = append(oes.refs, ref) return } - // The only remaining case is a pointer value that doesn't overlap with - // any registered addresses. Create a new entry for it, and start - // tracking the first reference we just created. - oes := &objectEncodeState{ + // No existing object overlaps this one. Register a new object. + oes = &objectEncodeState{ id: es.nextID(), obj: obj, } + if seg.Ok() { + gap = seg.PrevGap() + } else { + gap = es.values.LastGap() + } if !raceEnabled { - es.values.AddWithoutMerging(r, oes) + es.values.InsertWithoutMergingUnchecked(gap, r, oes) } else { - // Merges should never happen. This is just enabled extra - // sanity checks because the Merge function below will panic. - es.values.Add(r, oes) + es.values.Insert(gap, r, oes) } - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) ref.Root = wire.Uint(oes.id) oes.refs = append(oes.refs, ref) @@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) { // encodeStruct encodes a composite object. func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { + if s, ok := es.encodedStructs[obj]; ok { + *dest = s + return + } + s := &wire.Struct{} + *dest = s + es.encodedStructs[obj] = s + // Ensure that the obj is addressable. There are two cases when it is // not. First, is when this is dispatched via SaveValue. Second, when // this is a map key as a struct. Either way, we need to make a copy to @@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { obj = localObj.Elem() } - // Prepare the value. - s := &wire.Struct{} - *dest = s - // Look the type up in the database. te, ok := es.types.Lookup(obj.Type()) if te == nil { @@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) { Failf("encoding error at object %#v: %w", oes.obj.Interface(), err) } - // Check that items are pending. - if es.pending.Front() == nil { + // Check that we have objects to serialize. + if len(es.pending) == 0 { Failf("pending is empty?") } - // Write the header with the number of objects. Note that there is no - // way that es.lastID could conflict with objectID, which would - // indicate that an impossibly large encoding. - if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil { + // Write the header with the number of objects. + if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil { Failf("error writing header: %w", err) } // Serialize all pending types and pending objects. Note that we don't // bother removing from this list as we walk it because that just // wastes time. It will not change after this point. - var id objectID if err := safely(func() { for _, wt := range es.pendingTypes { // Encode the type. wire.Save(es.w, &wt) } - for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() { - id++ // First object is 1. - if oes.id != id { - Failf("expected id %d, got %d", id, oes.id) - } - - // Marshall the object. + // Emit objects in ID order. + ids := make([]objectID, 0, len(es.pending)) + for id := range es.pending { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { + return ids[i] < ids[j] + }) + for _, id := range ids { + // Encode the id. + wire.Save(es.w, wire.Uint(id)) + // Marshal the object. + oes := es.pending[id] wire.Save(es.w, oes.encoded) } }); err != nil { // Include the object and the error. Failf("error serializing object %#v: %w", oes.encoded, err) } - - // Check what we wrote. - if id != es.lastID { - Failf("expected %d objects, wrote %d", es.lastID, id) - } } // objectFlag indicates that the length is a # of objects, rather than a raw @@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error { }) } -// pendingMapper is for the pending list. -type pendingMapper struct{} - -func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry } - // deferredMapper is for the deferred list. type deferredMapper struct{} diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go index 887f453a9..c6e8bb31d 100644 --- a/pkg/state/pretty/pretty.go +++ b/pkg/state/pretty/pretty.go @@ -42,6 +42,7 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { buf.WriteString(typ) buf.WriteString(")(") buf.WriteString(baseRef) + buf.WriteString(")") for _, component := range x.Dots { switch v := component.(type) { case *wire.FieldName: @@ -53,7 +54,6 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component))) } } - buf.WriteString(")") fullRef = buf.String() } if p.html { @@ -242,19 +242,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { // Note that this loop must match the general structure of the // loop in decode.go. But we don't register type information, // etc. and just print the raw structures. + type objectAndID struct { + id uint64 + obj wire.Object + } var ( tid uint64 = 1 - objects []wire.Object + objects []objectAndID ) - for oid := uint64(1); oid <= length; { - // Unmarshal the object. + for i := uint64(0); i < length; { + // Unmarshal either a type object or object ID. encoded := wire.Load(r) - - // Is this a type? - if typ, ok := encoded.(*wire.Type); ok { + switch we := encoded.(type) { + case *wire.Type: str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dt%d", graph, tid) - p.typeSpecs[tag] = typ + p.typeSpecs[tag] = we if p.html { // See below. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) @@ -263,20 +266,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { return err } tid++ - continue + case wire.Uint: + // Unmarshal the actual object. + objects = append(objects, objectAndID{ + id: uint64(we), + obj: wire.Load(r), + }) + i++ + default: + return fmt.Errorf("wanted type or object ID, got %#v", encoded) } - - // Otherwise, it is a node. - objects = append(objects, encoded) - oid++ } - for i, encoded := range objects { - // oid starts at 1. - oid := i + 1 + for _, objAndID := range objects { // Format the node. - str, _ := p.format(graph, 0, encoded) - tag := fmt.Sprintf("g%dr%d", graph, oid) + str, _ := p.format(graph, 0, objAndID.obj) + tag := fmt.Sprintf("g%dr%d", graph, objAndID.id) if p.html { // Create a little tag with an anchor next to it for linking. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) diff --git a/pkg/state/state.go b/pkg/state/state.go index acb629969..6b8540f03 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error { func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) { // Create the encoding state. es := encodeState{ - ctx: ctx, - w: w, - types: makeTypeEncodeDatabase(), - zeroValues: make(map[reflect.Type]*objectEncodeState), + ctx: ctx, + w: w, + types: makeTypeEncodeDatabase(), + zeroValues: make(map[reflect.Type]*objectEncodeState), + pending: make(map[objectID]*objectEncodeState), + encodedStructs: make(map[reflect.Value]*wire.Struct), } // Perform the encoding. diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go index d3931c952..2b1609af0 100644 --- a/pkg/state/tests/integer_test.go +++ b/pkg/state/tests/integer_test.go @@ -20,21 +20,21 @@ import ( ) var ( - allIntTs = []int{-1, 0, 1} - allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} - allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} - allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} - allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} - allUintTs = []uint{0, 1} - allUintptrs = []uintptr{0, 1, ^uintptr(0)} - allUint8s = []uint8{0, 1, math.MaxUint8} - allUint16s = []uint16{0, 1, math.MaxUint16} - allUint32s = []uint32{0, 1, math.MaxUint32} - allUint64s = []uint64{0, 1, math.MaxUint64} + allBasicInts = []int{-1, 0, 1} + allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} + allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} + allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} + allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} + allBasicUints = []uint{0, 1} + allUintptrs = []uintptr{0, 1, ^uintptr(0)} + allUint8s = []uint8{0, 1, math.MaxUint8} + allUint16s = []uint16{0, 1, math.MaxUint16} + allUint32s = []uint32{0, 1, math.MaxUint32} + allUint64s = []uint64{0, 1, math.MaxUint64} ) var allInts = flatten( - allIntTs, + allBasicInts, allInt8s, allInt16s, allInt32s, @@ -42,7 +42,7 @@ var allInts = flatten( ) var allUints = flatten( - allUintTs, + allBasicUints, allUintptrs, allUint8s, allUint16s, diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go index bd2c2b399..69143d194 100644 --- a/pkg/state/tests/struct.go +++ b/pkg/state/tests/struct.go @@ -54,12 +54,47 @@ type outerArray struct { } // +stateify savable +type outerSlice struct { + inner []inner +} + +// +stateify savable type inner struct { v int64 } // +stateify savable +type outerFieldValue struct { + inner innerFieldValue +} + +// +stateify savable +type innerFieldValue struct { + v int64 `state:".(*savedFieldValue)"` +} + +// +stateify savable +type savedFieldValue struct { + v int64 +} + +func (ifv *innerFieldValue) saveV() *savedFieldValue { + return &savedFieldValue{ifv.v} +} + +func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) { + ifv.v = sfv.v +} + +// +stateify savable type system struct { v1 interface{} v2 interface{} } + +// +stateify savable +type system3 struct { + v1 interface{} + v2 interface{} + v3 interface{} +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go index de9d17aa7..c91c2c032 100644 --- a/pkg/state/tests/struct_test.go +++ b/pkg/state/tests/struct_test.go @@ -15,6 +15,7 @@ package tests import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/state" @@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) { } func TestEmbeddedPointers(t *testing.T) { - var ( - ofs outerSame - of1 outerFieldFirst - of2 outerFieldSecond - oa outerArray - ) + // Give each int64 a random value to prevent Go from using + // runtime.staticuint64s, which confounds tests for struct duplication. + magic := func() int64 { + for { + n := rand.Int63() + if n < 0 || n > 255 { + return n + } + } + } + + ofs := outerSame{inner{magic()}} + of1 := outerFieldFirst{inner{magic()}, magic()} + of2 := outerFieldSecond{magic(), inner{magic()}} + oa := outerArray{[2]inner{{magic()}, {magic()}}} + osl := outerSlice{oa.inner[:]} + ofv := outerFieldValue{innerFieldValue{magic()}} runTestCases(t, false, "embedded-pointers", []interface{}{ system{&ofs, &ofs.inner}, @@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) { system{&oa, &oa.inner[1]}, system{&oa.inner[0], &oa}, system{&oa.inner[1], &oa}, + system3{&oa, &oa.inner[0], &oa.inner[1]}, + system3{&oa, &oa.inner[1], &oa.inner[0]}, + system3{&oa.inner[0], &oa, &oa.inner[1]}, + system3{&oa.inner[1], &oa, &oa.inner[0]}, + system3{&oa.inner[0], &oa.inner[1], &oa}, + system3{&oa.inner[1], &oa.inner[0], &oa}, + system{&oa, &osl}, + system{&osl, &oa}, + system{&ofv, &ofv.inner}, + system{&ofv.inner, &ofv}, }) } diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 68535c3b1..28e62abbb 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -10,15 +10,34 @@ exports_files(["LICENSE"]) go_template( name = "generic_atomicptr", - srcs = ["atomicptr_unsafe.go"], + srcs = ["generic_atomicptr_unsafe.go"], types = [ "Value", ], ) go_template( + name = "generic_atomicptrmap", + srcs = ["generic_atomicptrmap_unsafe.go"], + opt_consts = [ + "ShardOrder", + ], + opt_types = [ + "Hasher", + ], + types = [ + "Key", + "Value", + ], + deps = [ + ":sync", + "//pkg/gohacks", + ], +) + +go_template( name = "generic_seqatomic", - srcs = ["seqatomic_unsafe.go"], + srcs = ["generic_seqatomic_unsafe.go"], types = [ "Value", ], @@ -31,18 +50,26 @@ go_library( name = "sync", srcs = [ "aliases.go", - "memmove_unsafe.go", + "checklocks_off_unsafe.go", + "checklocks_on_unsafe.go", + "goyield_go113_unsafe.go", + "goyield_unsafe.go", "mutex_unsafe.go", "nocopy.go", "norace_unsafe.go", + "race_amd64.s", + "race_arm64.s", "race_unsafe.go", + "runtime_unsafe.go", "rwmutex_unsafe.go", "seqcount.go", - "spin_unsafe.go", "sync.go", ], marshal = False, stateify = False, + deps = [ + "//pkg/goid", + ], ) go_test( diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmaptest/BUILD new file mode 100644 index 000000000..3f71ae97d --- /dev/null +++ b/pkg/sync/atomicptrmaptest/BUILD @@ -0,0 +1,57 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], +) + +go_template_instance( + name = "test_atomicptrmap", + out = "test_atomicptrmap_unsafe.go", + package = "atomicptrmap", + prefix = "test", + template = "//pkg/sync:generic_atomicptrmap", + types = { + "Key": "int64", + "Value": "testValue", + }, +) + +go_template_instance( + name = "test_atomicptrmap_sharded", + out = "test_atomicptrmap_sharded_unsafe.go", + consts = { + "ShardOrder": "4", + }, + package = "atomicptrmap", + prefix = "test", + suffix = "Sharded", + template = "//pkg/sync:generic_atomicptrmap", + types = { + "Key": "int64", + "Value": "testValue", + }, +) + +go_library( + name = "atomicptrmap", + testonly = 1, + srcs = [ + "atomicptrmap.go", + "test_atomicptrmap_sharded_unsafe.go", + "test_atomicptrmap_unsafe.go", + ], + deps = [ + "//pkg/gohacks", + "//pkg/sync", + ], +) + +go_test( + name = "atomicptrmap_test", + size = "small", + srcs = ["atomicptrmap_test.go"], + library = ":atomicptrmap", + deps = ["//pkg/sync"], +) diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap.go b/pkg/sync/atomicptrmaptest/atomicptrmap.go new file mode 100644 index 000000000..867821ce9 --- /dev/null +++ b/pkg/sync/atomicptrmaptest/atomicptrmap.go @@ -0,0 +1,20 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package atomicptrmap instantiates generic_atomicptrmap for testing. +package atomicptrmap + +type testValue struct { + val int +} diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go new file mode 100644 index 000000000..75a9997ef --- /dev/null +++ b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go @@ -0,0 +1,635 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomicptrmap + +import ( + "context" + "fmt" + "math/rand" + "reflect" + "runtime" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +func TestConsistencyWithGoMap(t *testing.T) { + const maxKey = 16 + var vals [4]*testValue + for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { + vals[i] = new(testValue) + } + var ( + m = make(map[int64]*testValue) + apm testAtomicPtrMap + ) + for i := 0; i < 100000; i++ { + // Apply a random operation to both m and apm and expect them to have + // the same result. Bias toward CompareAndSwap, which has the most + // cases; bias away from Range and RangeRepeatable, which are + // relatively expensive. + switch rand.Intn(10) { + case 0, 1: // Load + key := rand.Int63n(maxKey) + want := m[key] + got := apm.Load(key) + t.Logf("Load(%d) = %p", key, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 2, 3: // Swap + key := rand.Int63n(maxKey) + val := vals[rand.Intn(len(vals))] + want := m[key] + if val != nil { + m[key] = val + } else { + delete(m, key) + } + got := apm.Swap(key, val) + t.Logf("Swap(%d, %p) = %p", key, val, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 4, 5, 6, 7: // CompareAndSwap + key := rand.Int63n(maxKey) + oldVal := vals[rand.Intn(len(vals))] + newVal := vals[rand.Intn(len(vals))] + want := m[key] + if want == oldVal { + if newVal != nil { + m[key] = newVal + } else { + delete(m, key) + } + } + got := apm.CompareAndSwap(key, oldVal, newVal) + t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 8: // Range + got := make(map[int64]*testValue) + var ( + haveDup = false + dup int64 + ) + apm.Range(func(key int64, val *testValue) bool { + if _, ok := got[key]; ok && !haveDup { + haveDup = true + dup = key + } + got[key] = val + return true + }) + t.Logf("Range() = %v", got) + if !reflect.DeepEqual(got, m) { + t.Fatalf("got %v, wanted %v", got, m) + } + if haveDup { + t.Fatalf("got duplicate key %d", dup) + } + case 9: // RangeRepeatable + got := make(map[int64]*testValue) + apm.RangeRepeatable(func(key int64, val *testValue) bool { + got[key] = val + return true + }) + t.Logf("RangeRepeatable() = %v", got) + if !reflect.DeepEqual(got, m) { + t.Fatalf("got %v, wanted %v", got, m) + } + } + } +} + +func TestConcurrentHeterogeneous(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var ( + apm testAtomicPtrMap + wg sync.WaitGroup + ) + defer func() { + cancel() + wg.Wait() + }() + + possibleKeyValuePairs := make(map[int64]map[*testValue]struct{}) + addKeyValuePair := func(key int64, val *testValue) { + values := possibleKeyValuePairs[key] + if values == nil { + values = make(map[*testValue]struct{}) + possibleKeyValuePairs[key] = values + } + values[val] = struct{}{} + } + + const numValuesPerKey = 4 + + // These goroutines use keys not used by any other goroutine. + const numPrivateKeys = 3 + for i := 0; i < numPrivateKeys; i++ { + key := int64(i) + var vals [numValuesPerKey]*testValue + for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { + val := new(testValue) + vals[i] = val + addKeyValuePair(key, val) + } + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + var stored *testValue + for ctx.Err() == nil { + switch r.Intn(4) { + case 0: + got := apm.Load(key) + if got != stored { + t.Errorf("Load(%d): got %p, wanted %p", key, got, stored) + return + } + case 1: + val := vals[r.Intn(len(vals))] + want := stored + stored = val + got := apm.Swap(key, val) + if got != want { + t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want) + return + } + case 2, 3: + oldVal := vals[r.Intn(len(vals))] + newVal := vals[r.Intn(len(vals))] + want := stored + if stored == oldVal { + stored = newVal + } + got := apm.CompareAndSwap(key, oldVal, newVal) + if got != want { + t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want) + return + } + } + } + }() + } + + // These goroutines share a small set of keys. + const numSharedKeys = 2 + var ( + sharedKeys [numSharedKeys]int64 + sharedValues = make(map[int64][]*testValue) + sharedValuesSet = make(map[int64]map[*testValue]struct{}) + ) + for i := range sharedKeys { + key := int64(numPrivateKeys + i) + sharedKeys[i] = key + vals := make([]*testValue, numValuesPerKey) + valsSet := make(map[*testValue]struct{}) + for j := range vals { + val := new(testValue) + vals[j] = val + valsSet[val] = struct{}{} + addKeyValuePair(key, val) + } + sharedValues[key] = vals + sharedValuesSet[key] = valsSet + } + randSharedValue := func(r *rand.Rand, key int64) *testValue { + vals := sharedValues[key] + return vals[r.Intn(len(vals))] + } + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + for ctx.Err() == nil { + keyIndex := r.Intn(len(sharedKeys)) + key := sharedKeys[keyIndex] + var ( + op string + got *testValue + ) + switch r.Intn(4) { + case 0: + op = "Load" + got = apm.Load(key) + case 1: + op = "Swap" + got = apm.Swap(key, randSharedValue(r, key)) + case 2, 3: + op = "CompareAndSwap" + got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key)) + } + if got != nil { + valsSet := sharedValuesSet[key] + if _, ok := valsSet[got]; !ok { + t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet) + return + } + } + } + }() + } + + // This goroutine repeatedly searches for unused keys. + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + for ctx.Err() == nil { + key := -1 - r.Int63() + if got := apm.Load(key); got != nil { + t.Errorf("Load(%d): got %p, wanted nil", key, got) + } + } + }() + + // This goroutine repeatedly calls RangeRepeatable() and checks that each + // key corresponds to an expected value. + wg.Add(1) + go func() { + defer wg.Done() + abort := false + for !abort && ctx.Err() == nil { + apm.RangeRepeatable(func(key int64, val *testValue) bool { + values, ok := possibleKeyValuePairs[key] + if !ok { + t.Errorf("RangeRepeatable: got invalid key %d", key) + abort = true + return false + } + if _, ok := values[val]; !ok { + t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values) + abort = true + return false + } + return true + }) + } + }() + + // Finally, the main goroutine spins for the length of the test calling + // Range() and checking that each key that it observes is unique and + // corresponds to an expected value. + seenKeys := make(map[int64]struct{}) + const testDuration = 5 * time.Second + end := time.Now().Add(testDuration) + abort := false + for time.Now().Before(end) { + apm.Range(func(key int64, val *testValue) bool { + values, ok := possibleKeyValuePairs[key] + if !ok { + t.Errorf("Range: got invalid key %d", key) + abort = true + return false + } + if _, ok := values[val]; !ok { + t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values) + abort = true + return false + } + if _, ok := seenKeys[key]; ok { + t.Errorf("Range: got duplicate key %d", key) + abort = true + return false + } + seenKeys[key] = struct{}{} + return true + }) + if abort { + break + } + for k := range seenKeys { + delete(seenKeys, k) + } + } +} + +type benchmarkableMap interface { + Load(key int64) *testValue + Store(key int64, val *testValue) + LoadOrStore(key int64, val *testValue) (*testValue, bool) + Delete(key int64) +} + +// rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map. +type rwMutexMap struct { + mu sync.RWMutex + m map[int64]*testValue +} + +func (m *rwMutexMap) Load(key int64) *testValue { + m.mu.RLock() + defer m.mu.RUnlock() + return m.m[key] +} + +func (m *rwMutexMap) Store(key int64, val *testValue) { + m.mu.Lock() + defer m.mu.Unlock() + if m.m == nil { + m.m = make(map[int64]*testValue) + } + m.m[key] = val +} + +func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + m.mu.Lock() + defer m.mu.Unlock() + if m.m == nil { + m.m = make(map[int64]*testValue) + } + if oldVal, ok := m.m[key]; ok { + return oldVal, true + } + m.m[key] = val + return val, false +} + +func (m *rwMutexMap) Delete(key int64) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.m, key) +} + +// syncMap implements benchmarkableMap for a sync.Map. +type syncMap struct { + m sync.Map +} + +func (m *syncMap) Load(key int64) *testValue { + val, ok := m.m.Load(key) + if !ok { + return nil + } + return val.(*testValue) +} + +func (m *syncMap) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + actual, loaded := m.m.LoadOrStore(key, val) + return actual.(*testValue), loaded +} + +func (m *syncMap) Delete(key int64) { + m.m.Delete(key) +} + +// benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap. +type benchmarkableAtomicPtrMap struct { + m testAtomicPtrMap +} + +func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue { + return m.m.Load(key) +} + +func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { + return prev, true + } + return val, false +} + +func (m *benchmarkableAtomicPtrMap) Delete(key int64) { + m.m.Store(key, nil) +} + +// benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded. +type benchmarkableAtomicPtrMapSharded struct { + m testAtomicPtrMapSharded +} + +func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue { + return m.m.Load(key) +} + +func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { + return prev, true + } + return val, false +} + +func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) { + m.m.Store(key, nil) +} + +var mapImpls = [...]struct { + name string + ctor func() benchmarkableMap +}{ + { + name: "RWMutexMap", + ctor: func() benchmarkableMap { + return new(rwMutexMap) + }, + }, + { + name: "SyncMap", + ctor: func() benchmarkableMap { + return new(syncMap) + }, + }, + { + name: "AtomicPtrMap", + ctor: func() benchmarkableMap { + return new(benchmarkableAtomicPtrMap) + }, + }, + { + name: "AtomicPtrMapSharded", + ctor: func() benchmarkableMap { + return new(benchmarkableAtomicPtrMapSharded) + }, + }, +} + +func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + for i := 0; i < b.N; i++ { + m.Delete(int64(i)) + } +} + +func BenchmarkStoreDelete(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkStoreDelete(b, mapImpl.ctor) + }) + } +} + +func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.LoadOrStore(int64(i), val) + } + for i := 0; i < b.N; i++ { + m.Delete(int64(i)) + } +} + +func BenchmarkLoadOrStoreDelete(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLoadOrStoreDelete(b, mapImpl.ctor) + }) + } +} + +func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Load(int64(i)) + } +} + +func BenchmarkLookupPositive(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLookupPositive(b, mapImpl.ctor) + }) + } +} + +func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Load(int64(-1 - i)) + } +} + +func BenchmarkLookupNegative(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLookupNegative(b, mapImpl.ctor) + }) + } +} + +type benchmarkConcurrentOptions struct { + // loadsPerMutationPair is the number of map lookups between each + // insertion/deletion pair. + loadsPerMutationPair int + + // If changeKeys is true, the keys used by each goroutine change between + // iterations of the test. + changeKeys bool +} + +func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) { + var ( + started sync.WaitGroup + workers sync.WaitGroup + ) + started.Add(1) + + m := mapCtor() + val := &testValue{} + // Insert a large number of unused elements into the map so that used + // elements are distributed throughout memory. + for i := 0; i < 10000; i++ { + m.Store(int64(-1-i), val) + } + // n := ceil(b.N / (opts.loadsPerMutationPair + 2)) + n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2) + for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ { + workerID := i + workers.Add(1) + go func() { + defer workers.Done() + started.Wait() + for i := 0; i < n; i++ { + var key int64 + if opts.changeKeys { + key = int64(workerID*n + i) + } else { + key = int64(workerID) + } + m.LoadOrStore(key, val) + for j := 0; j < opts.loadsPerMutationPair; j++ { + m.Load(key) + } + m.Delete(key) + } + }() + } + + b.ResetTimer() + started.Done() + workers.Wait() +} + +func BenchmarkConcurrent(b *testing.B) { + changeKeysChoices := [...]struct { + name string + val bool + }{ + {"FixedKeys", false}, + {"ChangingKeys", true}, + } + writePcts := [...]struct { + name string + loadsPerMutationPair int + }{ + {"1PercentWrites", 198}, + {"10PercentWrites", 18}, + {"50PercentWrites", 2}, + } + for _, changeKeys := range changeKeysChoices { + for _, writePct := range writePcts { + for _, mapImpl := range mapImpls { + name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name) + b.Run(name, func(b *testing.B) { + benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{ + loadsPerMutationPair: writePct.loadsPerMutationPair, + changeKeys: changeKeys.val, + }) + }) + } + } + } +} diff --git a/pkg/sync/checklocks_off_unsafe.go b/pkg/sync/checklocks_off_unsafe.go new file mode 100644 index 000000000..62c81b149 --- /dev/null +++ b/pkg/sync/checklocks_off_unsafe.go @@ -0,0 +1,18 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !checklocks + +package sync + +import ( + "unsafe" +) + +func noteLock(l unsafe.Pointer) { +} + +func noteUnlock(l unsafe.Pointer) { +} diff --git a/pkg/sync/checklocks_on_unsafe.go b/pkg/sync/checklocks_on_unsafe.go new file mode 100644 index 000000000..24f933ed1 --- /dev/null +++ b/pkg/sync/checklocks_on_unsafe.go @@ -0,0 +1,108 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build checklocks + +package sync + +import ( + "fmt" + "strings" + "sync" + "unsafe" + + "gvisor.dev/gvisor/pkg/goid" +) + +// gLocks contains metadata about the locks held by a goroutine. +type gLocks struct { + locksHeld []unsafe.Pointer +} + +// map[goid int]*gLocks +// +// Each key may only be written by the G with the goid it refers to. +// +// Note that entries are not evicted when a G exit, causing unbounded growth +// with new G creation / destruction. If this proves problematic, entries could +// be evicted when no locks are held at the expense of more allocations when +// taking top-level locks. +var locksHeld sync.Map + +func getGLocks() *gLocks { + id := goid.Get() + + var locks *gLocks + if l, ok := locksHeld.Load(id); ok { + locks = l.(*gLocks) + } else { + locks = &gLocks{ + // Initialize space for a few locks. + locksHeld: make([]unsafe.Pointer, 0, 8), + } + locksHeld.Store(id, locks) + } + + return locks +} + +func noteLock(l unsafe.Pointer) { + locks := getGLocks() + + for _, lock := range locks.locksHeld { + if lock == l { + panic(fmt.Sprintf("Deadlock on goroutine %d! Double lock of %p: %+v", goid.Get(), l, locks)) + } + } + + // Commit only after checking for panic conditions so that this lock + // isn't on the list if the above panic is recovered. + locks.locksHeld = append(locks.locksHeld, l) +} + +func noteUnlock(l unsafe.Pointer) { + locks := getGLocks() + + if len(locks.locksHeld) == 0 { + panic(fmt.Sprintf("Unlock of %p on goroutine %d without any locks held! All locks:\n%s", l, goid.Get(), dumpLocks())) + } + + // Search backwards since callers are most likely to unlock in LIFO order. + length := len(locks.locksHeld) + for i := length - 1; i >= 0; i-- { + if l == locks.locksHeld[i] { + copy(locks.locksHeld[i:length-1], locks.locksHeld[i+1:length]) + // Clear last entry to ensure addr can be GC'd. + locks.locksHeld[length-1] = nil + locks.locksHeld = locks.locksHeld[:length-1] + return + } + } + + panic(fmt.Sprintf("Unlock of %p on goroutine %d without matching lock! All locks:\n%s", l, goid.Get(), dumpLocks())) +} + +func dumpLocks() string { + var s strings.Builder + locksHeld.Range(func(key, value interface{}) bool { + goid := key.(int64) + locks := value.(*gLocks) + + // N.B. accessing gLocks of another G is fundamentally racy. + + fmt.Fprintf(&s, "goroutine %d:\n", goid) + if len(locks.locksHeld) == 0 { + fmt.Fprintf(&s, "\t<none>\n") + } + for _, lock := range locks.locksHeld { + fmt.Fprintf(&s, "\t%p\n", lock) + } + fmt.Fprintf(&s, "\n") + + return true + }) + + return s.String() +} diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/generic_atomicptr_unsafe.go index 525c4beed..82b6df18c 100644 --- a/pkg/sync/atomicptr_unsafe.go +++ b/pkg/sync/generic_atomicptr_unsafe.go @@ -3,9 +3,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package template doesn't exist. This file must be instantiated using the +// Package seqatomic doesn't exist. This file must be instantiated using the // go_template_instance rule in tools/go_generics/defs.bzl. -package template +package seqatomic import ( "sync/atomic" diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/generic_atomicptrmap_unsafe.go new file mode 100644 index 000000000..c70dda6dd --- /dev/null +++ b/pkg/sync/generic_atomicptrmap_unsafe.go @@ -0,0 +1,503 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package atomicptrmap doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package atomicptrmap + +import ( + "reflect" + "runtime" + "sync/atomic" + "unsafe" + + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sync" +) + +// Key is a required type parameter. +type Key struct{} + +// Value is a required type parameter. +type Value struct{} + +const ( + // ShardOrder is an optional parameter specifying the base-2 log of the + // number of shards per AtomicPtrMap. Higher values of ShardOrder reduce + // unnecessary synchronization between unrelated concurrent operations, + // improving performance for write-heavy workloads, but increase memory + // usage for small maps. + ShardOrder = 0 +) + +// Hasher is an optional type parameter. If Hasher is provided, it must define +// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps. +type Hasher struct { + defaultHasher +} + +// defaultHasher is the default Hasher. This indirection exists because +// defaultHasher must exist even if a custom Hasher is provided, to prevent the +// Go compiler from complaining about defaultHasher's unused imports. +type defaultHasher struct { + fn func(unsafe.Pointer, uintptr) uintptr + seed uintptr +} + +// Init initializes the Hasher. +func (h *defaultHasher) Init() { + h.fn = sync.MapKeyHasher(map[Key]*Value(nil)) + h.seed = sync.RandUintptr() +} + +// Hash returns the hash value for the given Key. +func (h *defaultHasher) Hash(key Key) uintptr { + return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed) +} + +var hasher Hasher + +func init() { + hasher.Init() +} + +// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMap are +// safe for concurrent use from multiple goroutines without additional +// synchronization. +// +// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for +// use. AtomicPtrMaps must not be copied after first use. +// +// sync.Map may be faster than AtomicPtrMap if most operations on the map are +// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in +// other circumstances. +type AtomicPtrMap struct { + // AtomicPtrMap is implemented as a hash table with the following + // properties: + // + // * Collisions are resolved with quadratic probing. Of the two major + // alternatives, Robin Hood linear probing makes it difficult for writers + // to execute in parallel, and bucketing is less effective in Go due to + // lack of SIMD. + // + // * The table is optionally divided into shards indexed by hash to further + // reduce unnecessary synchronization. + + shards [1 << ShardOrder]apmShard +} + +func (m *AtomicPtrMap) shard(hash uintptr) *apmShard { + // Go defines right shifts >= width of shifted unsigned operand as 0, so + // this is correct even if ShardOrder is 0 (although nogo complains because + // nogo is dumb). + const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder + index := hash >> indexLSB + return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{})))) +} + +type apmShard struct { + apmShardMutationData + _ [apmShardMutationDataPadding]byte + apmShardLookupData + _ [apmShardLookupDataPadding]byte +} + +type apmShardMutationData struct { + dirtyMu sync.Mutex // serializes slot transitions out of empty + dirty uintptr // # slots with val != nil + count uintptr // # slots with val != nil and val != tombstone() + rehashMu sync.Mutex // serializes rehashing +} + +type apmShardLookupData struct { + seq sync.SeqCount // allows atomic reads of slots+mask + slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq + mask uintptr // always (a power of 2) - 1; protected by rehashMu/seq +} + +const ( + cacheLineBytes = 64 + // Cache line padding is enabled if sharding is. + apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise + // The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) % + // cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes). + apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1) + apmShardMutationDataPadding = apmEnablePadding * apmShardMutationDataRequiredPadding + apmShardLookupDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1) + apmShardLookupDataPadding = apmEnablePadding * apmShardLookupDataRequiredPadding + + // These define fractional thresholds for when apmShard.rehash() is called + // (i.e. the load factor) and when it rehases to a larger table + // respectively. They are chosen such that the rehash threshold = the + // expansion threshold + 1/2, so that when reuse of deleted slots is rare + // or non-existent, rehashing occurs after the insertion of at least 1/2 + // the table's size in new entries, which is acceptably infrequent. + apmRehashThresholdNum = 2 + apmRehashThresholdDen = 3 + apmExpansionThresholdNum = 1 + apmExpansionThresholdDen = 6 +) + +type apmSlot struct { + // slot states are indicated by val: + // + // * Empty: val == nil; key is meaningless. May transition to full or + // evacuated with dirtyMu locked. + // + // * Full: val != nil, tombstone(), or evacuated(); key is immutable. val + // is the Value mapped to key. May transition to deleted or evacuated. + // + // * Deleted: val == tombstone(); key is still immutable. key is mapped to + // no Value. May transition to full or evacuated. + // + // * Evacuated: val == evacuated(); key is immutable. Set by rehashing on + // slots that have already been moved, requiring readers to wait for + // rehashing to complete and use the new table. Terminal state. + // + // Note that once val is non-nil, it cannot become nil again. That is, the + // transition from empty to non-empty is irreversible for a given slot; + // the only way to create more empty slots is by rehashing. + val unsafe.Pointer + key Key +} + +func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot { + return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{}))) +} + +var tombstoneObj byte + +func tombstone() unsafe.Pointer { + return unsafe.Pointer(&tombstoneObj) +} + +var evacuatedObj byte + +func evacuated() unsafe.Pointer { + return unsafe.Pointer(&evacuatedObj) +} + +// Load returns the Value stored in m for key. +func (m *AtomicPtrMap) Load(key Key) *Value { + hash := hasher.Hash(key) + shard := m.shard(hash) + +retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + return nil + } + + i := hash & mask + inc := uintptr(1) + for { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil { + // Empty slot; end of probe sequence. + return nil + } + if slotVal == evacuated() { + // Racing with rehashing. + goto retry + } + if slot.key == key { + if slotVal == tombstone() { + return nil + } + return (*Value)(slotVal) + } + i = (i + inc) & mask + inc++ + } +} + +// Store stores the Value val for key. +func (m *AtomicPtrMap) Store(key Key, val *Value) { + m.maybeCompareAndSwap(key, false, nil, val) +} + +// Swap stores the Value val for key and returns the previously-mapped Value. +func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value { + return m.maybeCompareAndSwap(key, false, nil, val) +} + +// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it +// stores the Value newVal for key. CompareAndSwap returns the previous Value +// stored for key, whether or not it stores newVal. +func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value { + return m.maybeCompareAndSwap(key, true, oldVal, newVal) +} + +func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value { + hash := hasher.Hash(key) + shard := m.shard(hash) + oldVal := tombstone() + if typedOldVal != nil { + oldVal = unsafe.Pointer(typedOldVal) + } + newVal := tombstone() + if typedNewVal != nil { + newVal = unsafe.Pointer(typedNewVal) + } + +retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + if (compare && oldVal != tombstone()) || newVal == tombstone() { + return nil + } + // Need to allocate a table before insertion. + shard.rehash(nil) + goto retry + } + + i := hash & mask + inc := uintptr(1) + for { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil { + if (compare && oldVal != tombstone()) || newVal == tombstone() { + return nil + } + // Try to grab this slot for ourselves. + shard.dirtyMu.Lock() + slotVal = atomic.LoadPointer(&slot.val) + if slotVal == nil { + // Check if we need to rehash before dirtying a slot. + if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum { + shard.dirtyMu.Unlock() + shard.rehash(slots) + goto retry + } + slot.key = key + atomic.StorePointer(&slot.val, newVal) // transitions slot to full + shard.dirty++ + atomic.AddUintptr(&shard.count, 1) + shard.dirtyMu.Unlock() + return nil + } + // Raced with another store; the slot is no longer empty. Continue + // with the new value of slotVal since we may have raced with + // another store of key. + shard.dirtyMu.Unlock() + } + if slotVal == evacuated() { + // Racing with rehashing. + goto retry + } + if slot.key == key { + // We're reusing an existing slot, so rehashing isn't necessary. + for { + if (compare && oldVal != slotVal) || newVal == slotVal { + if slotVal == tombstone() { + return nil + } + return (*Value)(slotVal) + } + if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) { + if slotVal == tombstone() { + atomic.AddUintptr(&shard.count, 1) + return nil + } + if newVal == tombstone() { + atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */) + } + return (*Value)(slotVal) + } + slotVal = atomic.LoadPointer(&slot.val) + if slotVal == evacuated() { + goto retry + } + } + } + // This produces a triangular number sequence of offsets from the + // initially-probed position. + i = (i + inc) & mask + inc++ + } +} + +// rehash is marked nosplit to avoid preemption during table copying. +//go:nosplit +func (shard *apmShard) rehash(oldSlots unsafe.Pointer) { + shard.rehashMu.Lock() + defer shard.rehashMu.Unlock() + + if shard.slots != oldSlots { + // Raced with another call to rehash(). + return + } + + // Determine the size of the new table. Constraints: + // + // * The size of the table must be a power of two to ensure that every slot + // is visitable by every probe sequence under quadratic probing with + // triangular numbers. + // + // * The size of the table cannot decrease because even if shard.count is + // currently smaller than shard.dirty, concurrent stores that reuse + // existing slots can drive shard.count back up to a maximum of + // shard.dirty. + newSize := uintptr(8) // arbitrary initial size + if oldSlots != nil { + oldSize := shard.mask + 1 + newSize = oldSize + if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum { + newSize *= 2 + } + } + + // Allocate the new table. + newSlotsSlice := make([]apmSlot, newSize) + newSlotsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&newSlotsSlice)) + newSlots := unsafe.Pointer(newSlotsReflect.Data) + runtime.KeepAlive(newSlotsSlice) + newMask := newSize - 1 + + // Start a writer critical section now so that racing users of the old + // table that observe evacuated() wait for the new table. (But lock dirtyMu + // first since doing so may block, which we don't want to do during the + // writer critical section.) + shard.dirtyMu.Lock() + shard.seq.BeginWrite() + + if oldSlots != nil { + realCount := uintptr(0) + // Copy old entries to the new table. + oldMask := shard.mask + for i := uintptr(0); i <= oldMask; i++ { + oldSlot := apmSlotAt(oldSlots, i) + val := atomic.SwapPointer(&oldSlot.val, evacuated()) + if val == nil || val == tombstone() { + continue + } + hash := hasher.Hash(oldSlot.key) + j := hash & newMask + inc := uintptr(1) + for { + newSlot := apmSlotAt(newSlots, j) + if newSlot.val == nil { + newSlot.val = val + newSlot.key = oldSlot.key + break + } + j = (j + inc) & newMask + inc++ + } + realCount++ + } + // Update dirty to reflect that tombstones were not copied to the new + // table. Use realCount since a concurrent mutator may not have updated + // shard.count yet. + shard.dirty = realCount + } + + // Switch to the new table. + atomic.StorePointer(&shard.slots, newSlots) + atomic.StoreUintptr(&shard.mask, newMask) + + shard.seq.EndWrite() + shard.dirtyMu.Unlock() +} + +// Range invokes f on each Key-Value pair stored in m. If any call to f returns +// false, Range stops iteration and returns. +// +// Range does not necessarily correspond to any consistent snapshot of the +// Map's contents: no Key will be visited more than once, but if the Value for +// any Key is stored or deleted concurrently, Range may reflect any mapping for +// that Key from any point during the Range call. +// +// f must not call other methods on m. +func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) { + for si := 0; si < len(m.shards); si++ { + shard := &m.shards[si] + if !shard.doRange(f) { + return + } + } +} + +func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool { + // We have to lock rehashMu because if we handled races with rehashing by + // retrying, f could see the same key twice. + shard.rehashMu.Lock() + defer shard.rehashMu.Unlock() + slots := shard.slots + if slots == nil { + return true + } + mask := shard.mask + for i := uintptr(0); i <= mask; i++ { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil || slotVal == tombstone() { + continue + } + if !f(slot.key, (*Value)(slotVal)) { + return false + } + } + return true +} + +// RangeRepeatable is like Range, but: +// +// * RangeRepeatable may visit the same Key multiple times in the presence of +// concurrent mutators, possibly passing different Values to f in different +// calls. +// +// * It is safe for f to call other methods on m. +func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) { + for si := 0; si < len(m.shards); si++ { + shard := &m.shards[si] + + retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + continue + } + + for i := uintptr(0); i <= mask; i++ { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == evacuated() { + goto retry + } + if slotVal == nil || slotVal == tombstone() { + continue + } + if !f(slot.key, (*Value)(slotVal)) { + return + } + } + } +} diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go index 2184cb5ab..82b676abf 100644 --- a/pkg/sync/seqatomic_unsafe.go +++ b/pkg/sync/generic_seqatomic_unsafe.go @@ -3,25 +3,17 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package template doesn't exist. This file must be instantiated using the +// Package seqatomic doesn't exist. This file must be instantiated using the // go_template_instance rule in tools/go_generics/defs.bzl. -package template +package seqatomic import ( - "fmt" - "reflect" - "strings" "unsafe" "gvisor.dev/gvisor/pkg/sync" ) // Value is a required type parameter. -// -// Value must not contain any pointers, including interface objects, function -// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs -// containing any of the above. An init() function will panic if this property -// does not hold. type Value struct{} // SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race @@ -55,12 +47,3 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) ok = seq.ReadOk(epoch) return } - -func init() { - var val Value - typ := reflect.TypeOf(val) - name := typ.Name() - if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 { - panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) - } -} diff --git a/pkg/sync/goyield_go113_unsafe.go b/pkg/sync/goyield_go113_unsafe.go new file mode 100644 index 000000000..8aee0d455 --- /dev/null +++ b/pkg/sync/goyield_go113_unsafe.go @@ -0,0 +1,18 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.14 + +package sync + +import ( + "runtime" +) + +func goyield() { + // goyield is not available until Go 1.14. + runtime.Gosched() +} diff --git a/pkg/sync/spin_unsafe.go b/pkg/sync/goyield_unsafe.go index cafb2d065..672ee274d 100644 --- a/pkg/sync/spin_unsafe.go +++ b/pkg/sync/goyield_unsafe.go @@ -3,7 +3,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.13 +// +build go1.14 // +build !go1.17 // Check go:linkname function signatures when updating Go version. @@ -14,11 +14,5 @@ import ( _ "unsafe" // for go:linkname ) -//go:linkname canSpin sync.runtime_canSpin -func canSpin(i int) bool - -//go:linkname doSpin sync.runtime_doSpin -func doSpin() - //go:linkname goyield runtime.goyield func goyield() diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go deleted file mode 100644 index f5e630009..000000000 --- a/pkg/sync/memmove_unsafe.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.12 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - -package sync - -import ( - "unsafe" -) - -//go:linkname memmove runtime.memmove -//go:noescape -func memmove(to, from unsafe.Pointer, n uintptr) - -// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad<T>, which can't -// define it because go_generics can't update the go:linkname annotation. -// Furthermore, go:linkname silently doesn't work if the local name is exported -// (this is of course undocumented), which is why this indirection is -// necessary. -func Memmove(to, from unsafe.Pointer, n uintptr) { - memmove(to, from, n) -} diff --git a/pkg/sync/mutex_test.go b/pkg/sync/mutex_test.go index 0838248b4..4fb51a8ab 100644 --- a/pkg/sync/mutex_test.go +++ b/pkg/sync/mutex_test.go @@ -32,11 +32,11 @@ func TestStructSize(t *testing.T) { func TestFieldValues(t *testing.T) { var m Mutex m.Lock() - if got := *m.state(); got != mutexLocked { + if got := *m.m.state(); got != mutexLocked { t.Errorf("got locked sync.Mutex.state = %d, want = %d", got, mutexLocked) } m.Unlock() - if got := *m.state(); got != mutexUnlocked { + if got := *m.m.state(); got != mutexUnlocked { t.Errorf("got unlocked sync.Mutex.state = %d, want = %d", got, mutexUnlocked) } } diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go index f4c2e9642..21084b857 100644 --- a/pkg/sync/mutex_unsafe.go +++ b/pkg/sync/mutex_unsafe.go @@ -17,8 +17,9 @@ import ( "unsafe" ) -// Mutex is a try lock. -type Mutex struct { +// CrossGoroutineMutex is equivalent to Mutex, but it need not be unlocked by a +// the same goroutine that locked the mutex. +type CrossGoroutineMutex struct { sync.Mutex } @@ -27,7 +28,7 @@ type syncMutex struct { sema uint32 } -func (m *Mutex) state() *int32 { +func (m *CrossGoroutineMutex) state() *int32 { return &(*syncMutex)(unsafe.Pointer(&m.Mutex)).state } @@ -36,9 +37,9 @@ const ( mutexLocked = 1 ) -// TryLock tries to aquire the mutex. It returns true if it succeeds and false +// TryLock tries to acquire the mutex. It returns true if it succeeds and false // otherwise. TryLock does not block. -func (m *Mutex) TryLock() bool { +func (m *CrossGoroutineMutex) TryLock() bool { if atomic.CompareAndSwapInt32(m.state(), mutexUnlocked, mutexLocked) { if RaceEnabled { RaceAcquire(unsafe.Pointer(&m.Mutex)) @@ -47,3 +48,43 @@ func (m *Mutex) TryLock() bool { } return false } + +// Mutex is a mutual exclusion lock. The zero value for a Mutex is an unlocked +// mutex. +// +// A Mutex must not be copied after first use. +// +// A Mutex must be unlocked by the same goroutine that locked it. This +// invariant is enforced with the 'checklocks' build tag. +type Mutex struct { + m CrossGoroutineMutex +} + +// Lock locks m. If the lock is already in use, the calling goroutine blocks +// until the mutex is available. +func (m *Mutex) Lock() { + noteLock(unsafe.Pointer(m)) + m.m.Lock() +} + +// Unlock unlocks m. +// +// Preconditions: +// * m is locked. +// * m was locked by this goroutine. +func (m *Mutex) Unlock() { + noteUnlock(unsafe.Pointer(m)) + m.m.Unlock() +} + +// TryLock tries to acquire the mutex. It returns true if it succeeds and false +// otherwise. TryLock does not block. +func (m *Mutex) TryLock() bool { + // Note lock first to enforce proper locking even if unsuccessful. + noteLock(unsafe.Pointer(m)) + locked := m.m.TryLock() + if !locked { + noteUnlock(unsafe.Pointer(m)) + } + return locked +} diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go index 006055dd6..70b5f3a5e 100644 --- a/pkg/sync/norace_unsafe.go +++ b/pkg/sync/norace_unsafe.go @@ -8,6 +8,7 @@ package sync import ( + "sync/atomic" "unsafe" ) @@ -33,3 +34,13 @@ func RaceRelease(addr unsafe.Pointer) { // RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. func RaceReleaseMerge(addr unsafe.Pointer) { } + +// RaceUncheckedAtomicCompareAndSwapUintptr is equivalent to +// sync/atomic.CompareAndSwapUintptr, but is not checked by the race detector. +// This is necessary when implementing gopark callbacks, since no race context +// is available during their execution. +func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool { + // Use atomic.CompareAndSwapUintptr outside of race builds for + // inlinability. + return atomic.CompareAndSwapUintptr(ptr, old, new) +} diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/sync/race_amd64.s index 5e216b045..57bc0ec79 100644 --- a/pkg/syncevent/waiter_amd64.s +++ b/pkg/sync/race_amd64.s @@ -12,21 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build race +// +build amd64 + #include "textflag.h" -// See waiter_noasm_unsafe.go for a description of waiterUnlock. -// -// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool -TEXT ·waiterUnlock(SB),NOSPLIT,$0-24 +// func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool +TEXT ·RaceUncheckedAtomicCompareAndSwapUintptr(SB),NOSPLIT,$0-25 MOVQ ptr+0(FP), DI - MOVQ wg+8(FP), SI + MOVQ old+8(FP), AX + MOVQ new+16(FP), SI - MOVQ $·preparingG(SB), AX LOCK - CMPXCHGQ DI, 0(SI) + CMPXCHGQ SI, 0(DI) SETEQ AX - MOVB AX, ret+16(FP) + MOVB AX, ret+24(FP) RET diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/sync/race_arm64.s index f4c06f194..88f091fda 100644 --- a/pkg/syncevent/waiter_arm64.s +++ b/pkg/sync/race_arm64.s @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build race +// +build arm64 + #include "textflag.h" -// See waiter_noasm_unsafe.go for a description of waiterUnlock. -// -// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool -TEXT ·waiterUnlock(SB),NOSPLIT,$0-24 - MOVD wg+8(FP), R0 - MOVD $·preparingG(SB), R1 - MOVD ptr+0(FP), R2 +// func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool +TEXT ·RaceUncheckedAtomicCompareAndSwapUintptr(SB),NOSPLIT,$0-25 + MOVD ptr+0(FP), R0 + MOVD old+8(FP), R1 + MOVD new+16(FP), R1 again: LDAXR (R0), R3 CMP R1, R3 @@ -29,6 +30,6 @@ again: CBNZ R3, again ok: CSET EQ, R0 - MOVB R0, ret+16(FP) + MOVB R0, ret+24(FP) RET diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go index 31d8fa9a6..59985c270 100644 --- a/pkg/sync/race_unsafe.go +++ b/pkg/sync/race_unsafe.go @@ -39,3 +39,9 @@ func RaceRelease(addr unsafe.Pointer) { func RaceReleaseMerge(addr unsafe.Pointer) { runtime.RaceReleaseMerge(addr) } + +// RaceUncheckedAtomicCompareAndSwapUintptr is equivalent to +// sync/atomic.CompareAndSwapUintptr, but is not checked by the race detector. +// This is necessary when implementing gopark callbacks, since no race context +// is available during their execution. +func RaceUncheckedAtomicCompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go new file mode 100644 index 000000000..e925e2e5b --- /dev/null +++ b/pkg/sync/runtime_unsafe.go @@ -0,0 +1,129 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.17 + +// Check function signatures and constants when updating Go version. + +package sync + +import ( + "fmt" + "reflect" + "unsafe" +) + +// Note that go:linkname silently doesn't work if the local name is exported, +// necessitating an indirection for exported functions. + +// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>. +// +//go:nosplit +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +// Gopark is runtime.gopark. Gopark calls unlockf(pointer to runtime.g, lock); +// if unlockf returns true, Gopark blocks until Goready(pointer to runtime.g) +// is called. unlockf and its callees must be nosplit and norace, since stack +// splitting and race context are not available where it is called. +// +//go:nosplit +func Gopark(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int) { + gopark(unlockf, lock, reason, traceEv, traceskip) +} + +//go:linkname gopark runtime.gopark +func gopark(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int) + +// Goready is runtime.goready. +// +//go:nosplit +func Goready(gp uintptr, traceskip int) { + goready(gp, traceskip) +} + +//go:linkname goready runtime.goready +func goready(gp uintptr, traceskip int) + +// Values for the reason argument to gopark, from Go's src/runtime/runtime2.go. +const ( + WaitReasonSelect uint8 = 9 +) + +// Values for the traceEv argument to gopark, from Go's src/runtime/trace.go. +const ( + TraceEvGoBlockSelect byte = 24 +) + +// Rand32 returns a non-cryptographically-secure random uint32. +func Rand32() uint32 { + return fastrand() +} + +// Rand64 returns a non-cryptographically-secure random uint64. +func Rand64() uint64 { + return uint64(fastrand())<<32 | uint64(fastrand()) +} + +//go:linkname fastrand runtime.fastrand +func fastrand() uint32 + +// RandUintptr returns a non-cryptographically-secure random uintptr. +func RandUintptr() uintptr { + if unsafe.Sizeof(uintptr(0)) == 4 { + return uintptr(Rand32()) + } + return uintptr(Rand64()) +} + +// MapKeyHasher returns a hash function for pointers of m's key type. +// +// Preconditions: m must be a map. +func MapKeyHasher(m interface{}) func(unsafe.Pointer, uintptr) uintptr { + if rtyp := reflect.TypeOf(m); rtyp.Kind() != reflect.Map { + panic(fmt.Sprintf("sync.MapKeyHasher: m is %v, not map", rtyp)) + } + mtyp := *(**maptype)(unsafe.Pointer(&m)) + return mtyp.hasher +} + +type maptype struct { + size uintptr + ptrdata uintptr + hash uint32 + tflag uint8 + align uint8 + fieldAlign uint8 + kind uint8 + equal func(unsafe.Pointer, unsafe.Pointer) bool + gcdata *byte + str int32 + ptrToThis int32 + key unsafe.Pointer + elem unsafe.Pointer + bucket unsafe.Pointer + hasher func(unsafe.Pointer, uintptr) uintptr + // more fields +} + +// These functions are only used within the sync package. + +//go:linkname semacquire sync.runtime_Semacquire +func semacquire(s *uint32) + +//go:linkname semrelease sync.runtime_Semrelease +func semrelease(s *uint32, handoff bool, skipframes int) + +//go:linkname canSpin sync.runtime_canSpin +func canSpin(i int) bool + +//go:linkname doSpin sync.runtime_doSpin +func doSpin() diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go index ce667e825..5ca96d12b 100644 --- a/pkg/sync/rwmutex_test.go +++ b/pkg/sync/rwmutex_test.go @@ -102,7 +102,7 @@ func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone c } for i := 0; i < 100; i++ { } - n = atomic.AddInt32(activity, -1) + atomic.AddInt32(activity, -1) rwm.RUnlock() } cdone <- true diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go index b3b4dee78..4cf3fcd6e 100644 --- a/pkg/sync/rwmutex_unsafe.go +++ b/pkg/sync/rwmutex_unsafe.go @@ -3,11 +3,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build go1.13 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - // This is mostly copied from the standard library's sync/rwmutex.go. // // Happens-before relationships indicated to the race detector: @@ -23,16 +18,15 @@ import ( "unsafe" ) -//go:linkname runtimeSemacquire sync.runtime_Semacquire -func runtimeSemacquire(s *uint32) - -//go:linkname runtimeSemrelease sync.runtime_Semrelease -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) - -// RWMutex is identical to sync.RWMutex, but adds the DowngradeLock, -// TryLock and TryRLock methods. -type RWMutex struct { - w Mutex // held if there are pending writers +// CrossGoroutineRWMutex is equivalent to RWMutex, but it need not be unlocked +// by a the same goroutine that locked the mutex. +type CrossGoroutineRWMutex struct { + // w is held if there are pending writers + // + // We use CrossGoroutineMutex rather than Mutex because the lock + // annotation instrumentation in Mutex will trigger false positives in + // the race detector when called inside of RaceDisable. + w CrossGoroutineMutex writerSem uint32 // semaphore for writers to wait for completing readers readerSem uint32 // semaphore for readers to wait for completing writers readerCount int32 // number of pending readers @@ -43,7 +37,7 @@ const rwmutexMaxReaders = 1 << 30 // TryRLock locks rw for reading. It returns true if it succeeds and false // otherwise. It does not block. -func (rw *RWMutex) TryRLock() bool { +func (rw *CrossGoroutineRWMutex) TryRLock() bool { if RaceEnabled { RaceDisable() } @@ -67,13 +61,17 @@ func (rw *RWMutex) TryRLock() bool { } // RLock locks rw for reading. -func (rw *RWMutex) RLock() { +// +// It should not be used for recursive read locking; a blocked Lock call +// excludes new readers from acquiring the lock. See the documentation on the +// RWMutex type. +func (rw *CrossGoroutineRWMutex) RLock() { if RaceEnabled { RaceDisable() } if atomic.AddInt32(&rw.readerCount, 1) < 0 { // A writer is pending, wait for it. - runtimeSemacquire(&rw.readerSem) + semacquire(&rw.readerSem) } if RaceEnabled { RaceEnable() @@ -82,7 +80,10 @@ func (rw *RWMutex) RLock() { } // RUnlock undoes a single RLock call. -func (rw *RWMutex) RUnlock() { +// +// Preconditions: +// * rw is locked for reading. +func (rw *CrossGoroutineRWMutex) RUnlock() { if RaceEnabled { RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) RaceDisable() @@ -94,7 +95,7 @@ func (rw *RWMutex) RUnlock() { // A writer is pending. if atomic.AddInt32(&rw.readerWait, -1) == 0 { // The last reader unblocks the writer. - runtimeSemrelease(&rw.writerSem, false, 0) + semrelease(&rw.writerSem, false, 0) } } if RaceEnabled { @@ -104,7 +105,7 @@ func (rw *RWMutex) RUnlock() { // TryLock locks rw for writing. It returns true if it succeeds and false // otherwise. It does not block. -func (rw *RWMutex) TryLock() bool { +func (rw *CrossGoroutineRWMutex) TryLock() bool { if RaceEnabled { RaceDisable() } @@ -130,8 +131,9 @@ func (rw *RWMutex) TryLock() bool { return true } -// Lock locks rw for writing. -func (rw *RWMutex) Lock() { +// Lock locks rw for writing. If the lock is already locked for reading or +// writing, Lock blocks until the lock is available. +func (rw *CrossGoroutineRWMutex) Lock() { if RaceEnabled { RaceDisable() } @@ -141,7 +143,7 @@ func (rw *RWMutex) Lock() { r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders // Wait for active readers. if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { - runtimeSemacquire(&rw.writerSem) + semacquire(&rw.writerSem) } if RaceEnabled { RaceEnable() @@ -150,7 +152,10 @@ func (rw *RWMutex) Lock() { } // Unlock unlocks rw for writing. -func (rw *RWMutex) Unlock() { +// +// Preconditions: +// * rw is locked for writing. +func (rw *CrossGoroutineRWMutex) Unlock() { if RaceEnabled { RaceRelease(unsafe.Pointer(&rw.writerSem)) RaceRelease(unsafe.Pointer(&rw.readerSem)) @@ -163,7 +168,7 @@ func (rw *RWMutex) Unlock() { } // Unblock blocked readers, if any. for i := 0; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) + semrelease(&rw.readerSem, false, 0) } // Allow other writers to proceed. rw.w.Unlock() @@ -173,7 +178,10 @@ func (rw *RWMutex) Unlock() { } // DowngradeLock atomically unlocks rw for writing and locks it for reading. -func (rw *RWMutex) DowngradeLock() { +// +// Preconditions: +// * rw is locked for writing. +func (rw *CrossGoroutineRWMutex) DowngradeLock() { if RaceEnabled { RaceRelease(unsafe.Pointer(&rw.readerSem)) RaceDisable() @@ -186,7 +194,7 @@ func (rw *RWMutex) DowngradeLock() { // Unblock blocked readers, if any. Note that this loop starts as 1 since r // includes this goroutine. for i := 1; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) + semrelease(&rw.readerSem, false, 0) } // Allow other writers to proceed to rw.w.Lock(). Note that they will still // block on rw.writerSem since at least this reader exists, such that @@ -196,3 +204,91 @@ func (rw *RWMutex) DowngradeLock() { RaceEnable() } } + +// A RWMutex is a reader/writer mutual exclusion lock. The lock can be held by +// an arbitrary number of readers or a single writer. The zero value for a +// RWMutex is an unlocked mutex. +// +// A RWMutex must not be copied after first use. +// +// If a goroutine holds a RWMutex for reading and another goroutine might call +// Lock, no goroutine should expect to be able to acquire a read lock until the +// initial read lock is released. In particular, this prohibits recursive read +// locking. This is to ensure that the lock eventually becomes available; a +// blocked Lock call excludes new readers from acquiring the lock. +// +// A Mutex must be unlocked by the same goroutine that locked it. This +// invariant is enforced with the 'checklocks' build tag. +type RWMutex struct { + m CrossGoroutineRWMutex +} + +// TryRLock locks rw for reading. It returns true if it succeeds and false +// otherwise. It does not block. +func (rw *RWMutex) TryRLock() bool { + // Note lock first to enforce proper locking even if unsuccessful. + noteLock(unsafe.Pointer(rw)) + locked := rw.m.TryRLock() + if !locked { + noteUnlock(unsafe.Pointer(rw)) + } + return locked +} + +// RLock locks rw for reading. +// +// It should not be used for recursive read locking; a blocked Lock call +// excludes new readers from acquiring the lock. See the documentation on the +// RWMutex type. +func (rw *RWMutex) RLock() { + noteLock(unsafe.Pointer(rw)) + rw.m.RLock() +} + +// RUnlock undoes a single RLock call. +// +// Preconditions: +// * rw is locked for reading. +// * rw was locked by this goroutine. +func (rw *RWMutex) RUnlock() { + rw.m.RUnlock() + noteUnlock(unsafe.Pointer(rw)) +} + +// TryLock locks rw for writing. It returns true if it succeeds and false +// otherwise. It does not block. +func (rw *RWMutex) TryLock() bool { + // Note lock first to enforce proper locking even if unsuccessful. + noteLock(unsafe.Pointer(rw)) + locked := rw.m.TryLock() + if !locked { + noteUnlock(unsafe.Pointer(rw)) + } + return locked +} + +// Lock locks rw for writing. If the lock is already locked for reading or +// writing, Lock blocks until the lock is available. +func (rw *RWMutex) Lock() { + noteLock(unsafe.Pointer(rw)) + rw.m.Lock() +} + +// Unlock unlocks rw for writing. +// +// Preconditions: +// * rw is locked for writing. +// * rw was locked by this goroutine. +func (rw *RWMutex) Unlock() { + rw.m.Unlock() + noteUnlock(unsafe.Pointer(rw)) +} + +// DowngradeLock atomically unlocks rw for writing and locks it for reading. +// +// Preconditions: +// * rw is locked for writing. +func (rw *RWMutex) DowngradeLock() { + // No note change for DowngradeLock. + rw.m.DowngradeLock() +} diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go index 2c5d3df99..1f025f33c 100644 --- a/pkg/sync/seqcount.go +++ b/pkg/sync/seqcount.go @@ -6,8 +6,6 @@ package sync import ( - "fmt" - "reflect" "sync/atomic" ) @@ -27,9 +25,6 @@ import ( // - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other // operations to be made atomic with reads of SeqCount-protected data. // -// - SeqCount may be less flexible: as of this writing, SeqCount-protected data -// cannot include pointers. -// // - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected // data require instantiating function templates using go_generics (see // seqatomic.go). @@ -128,32 +123,3 @@ func (s *SeqCount) EndWrite() { panic("SeqCount.EndWrite outside writer critical section") } } - -// PointersInType returns a list of pointers reachable from values named -// valName of the given type. -// -// PointersInType is not exhaustive, but it is guaranteed that if typ contains -// at least one pointer, then PointersInTypeOf returns a non-empty list. -func PointersInType(typ reflect.Type, valName string) []string { - switch kind := typ.Kind(); kind { - case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: - return nil - - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: - return []string{valName} - - case reflect.Array: - return PointersInType(typ.Elem(), valName+"[]") - - case reflect.Struct: - var ptrs []string - for i, n := 0, typ.NumField(); i < n; i++ { - field := typ.Field(i) - ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) - } - return ptrs - - default: - return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} - } -} diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go index 6eb7b4b59..3f5592e3e 100644 --- a/pkg/sync/seqcount_test.go +++ b/pkg/sync/seqcount_test.go @@ -6,7 +6,6 @@ package sync import ( - "reflect" "testing" "time" ) @@ -99,55 +98,3 @@ func BenchmarkSeqCountReadUncontended(b *testing.B) { } }) } - -func TestPointersInType(t *testing.T) { - for _, test := range []struct { - name string // used for both test and value name - val interface{} - ptrs []string - }{ - { - name: "EmptyStruct", - val: struct{}{}, - }, - { - name: "Int", - val: int(0), - }, - { - name: "MixedStruct", - val: struct { - b bool - I int - ExportedPtr *struct{} - unexportedPtr *struct{} - arr [2]int - ptrArr [2]*int - nestedStruct struct { - nestedNonptr int - nestedPtr *int - } - structArr [1]struct { - nonptr int - ptr *int - } - }{}, - ptrs: []string{ - "MixedStruct.ExportedPtr", - "MixedStruct.unexportedPtr", - "MixedStruct.ptrArr[]", - "MixedStruct.nestedStruct.nestedPtr", - "MixedStruct.structArr[].ptr", - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - typ := reflect.TypeOf(test.val) - ptrs := PointersInType(typ, test.name) - t.Logf("Found pointers: %v", ptrs) - if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { - t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) - } - }) - } -} diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD index 0500a22cf..42c553308 100644 --- a/pkg/syncevent/BUILD +++ b/pkg/syncevent/BUILD @@ -9,10 +9,6 @@ go_library( "receiver.go", "source.go", "syncevent.go", - "waiter_amd64.s", - "waiter_arm64.s", - "waiter_asm_unsafe.go", - "waiter_noasm_unsafe.go", "waiter_unsafe.go", ], visibility = ["//:sandbox"], diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go deleted file mode 100644 index 0f74a689c..000000000 --- a/pkg/syncevent/waiter_noasm_unsafe.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// waiterUnlock is called from g0, so when the race detector is enabled, -// waiterUnlock must be implemented in assembly since no race context is -// available. -// -// +build !race -// +build !amd64,!arm64 - -package syncevent - -import ( - "sync/atomic" - "unsafe" -) - -// waiterUnlock is the "unlock function" passed to runtime.gopark by -// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g. -// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping -// should be aborted. -// -//go:nosplit -func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool { - // The only way this CAS can fail is if a call to Waiter.NotifyPending() - // has replaced *wg with nil, in which case we should not sleep. - return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), ptr) -} diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go index 518f18479..b6ed2852d 100644 --- a/pkg/syncevent/waiter_unsafe.go +++ b/pkg/syncevent/waiter_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.11 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - package syncevent import ( @@ -26,17 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" ) -//go:linkname gopark runtime.gopark -func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int) - -//go:linkname goready runtime.goready -func goready(g unsafe.Pointer, traceskip int) - -const ( - waitReasonSelect = 9 // Go: src/runtime/runtime2.go - traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go -) - // Waiter allows a goroutine to block on pending events received by a Receiver. // // Waiter.Init() must be called before first use. @@ -45,20 +29,19 @@ type Waiter struct { // g is one of: // - // - nil: No goroutine is blocking in Wait. + // - 0: No goroutine is blocking in Wait. // - // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet + // - preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet // completed waiterUnlock(). Thus the wait can only be interrupted by - // replacing the value of g with nil (the G may not be in state Gwaiting - // yet, so we can't call goready.) + // replacing the value of g with 0 (the G may not be in state Gwaiting yet, + // so we can't call goready.) // // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the // goroutine blocked in Wait, which can only be woken by calling goready. - g unsafe.Pointer `state:"zerovalue"` + g uintptr `state:"zerovalue"` } -// Sentinel object for Waiter.g. -var preparingG struct{} +const preparingG = 1 // Init must be called before first use of w. func (w *Waiter) Init() { @@ -99,21 +82,29 @@ func (w *Waiter) WaitFor(es Set) Set { } // Indicate that we're preparing to go to sleep. - atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG)) + atomic.StoreUintptr(&w.g, preparingG) // If an event is pending, abort the sleep. if p := w.r.Pending(); p&es != NoEvents { - atomic.StorePointer(&w.g, nil) + atomic.StoreUintptr(&w.g, 0) return p } // If w.g is still preparingG (i.e. w.NotifyPending() has not been - // called or has not reached atomic.SwapPointer()), go to sleep until + // called or has not reached atomic.SwapUintptr()), go to sleep until // w.NotifyPending() => goready(). - gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0) + sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0) } } +//go:norace +//go:nosplit +func waiterCommit(g uintptr, wg unsafe.Pointer) bool { + // The only way this CAS can fail is if a call to Waiter.NotifyPending() + // has replaced *wg with nil, in which case we should not sleep. + return sync.RaceUncheckedAtomicCompareAndSwapUintptr((*uintptr)(wg), preparingG, g) +} + // Ack marks the given events as not pending. func (w *Waiter) Ack(es Set) { w.r.Ack(es) @@ -135,20 +126,20 @@ func (w *Waiter) WaitAndAckAll() Set { for { // Indicate that we're preparing to go to sleep. - atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG)) + atomic.StoreUintptr(&w.g, preparingG) // If an event is pending, abort the sleep. if w.r.Pending() != NoEvents { if p := w.r.PendingAndAckAll(); p != NoEvents { - atomic.StorePointer(&w.g, nil) + atomic.StoreUintptr(&w.g, 0) return p } } // If w.g is still preparingG (i.e. w.NotifyPending() has not been - // called or has not reached atomic.SwapPointer()), go to sleep until + // called or has not reached atomic.SwapUintptr()), go to sleep until // w.NotifyPending() => goready(). - gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0) + sync.Gopark(waiterCommit, unsafe.Pointer(&w.g), sync.WaitReasonSelect, sync.TraceEvGoBlockSelect, 0) // Check for pending events. We call PendingAndAckAll() directly now since // we only expect to be woken after events become pending. @@ -171,14 +162,14 @@ func (w *Waiter) NotifyPending() { // goroutine. NotifyPending is called after w.r.Pending() is updated, so // concurrent and future calls to w.Wait() will observe pending events and // abort sleeping. - if atomic.LoadPointer(&w.g) == nil { + if atomic.LoadUintptr(&w.g) == 0 { return } // Wake a sleeping G, or prevent a G that is preparing to sleep from doing // so. Swap is needed here to ensure that only one call to NotifyPending // calls goready. - if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) { - goready(g, 0) + if g := atomic.SwapUintptr(&w.g, 0); g > preparingG { + sync.Goready(g, 0) } } diff --git a/pkg/syserr/host_linux.go b/pkg/syserr/host_linux.go index fc6ef60a1..77faa3670 100644 --- a/pkg/syserr/host_linux.go +++ b/pkg/syserr/host_linux.go @@ -32,7 +32,7 @@ var linuxHostTranslations [maxErrno]linuxHostTranslation // FromHost translates a syscall.Errno to a corresponding Error value. func FromHost(err syscall.Errno) *Error { - if err < 0 || int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok { + if int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok { panic(fmt.Sprintf("unknown host errno %q (%d)", err.Error(), err)) } return linuxHostTranslations[err].err diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 5ae10939d..77c3c110c 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -15,6 +15,8 @@ package syserr import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -48,45 +50,56 @@ var ( ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM) ) -var netstackErrorTranslations = map[*tcpip.Error]*Error{ - tcpip.ErrUnknownProtocol: ErrUnknownProtocol, - tcpip.ErrUnknownNICID: ErrUnknownNICID, - tcpip.ErrUnknownDevice: ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID: ErrDuplicateNICID, - tcpip.ErrDuplicateAddress: ErrDuplicateAddress, - tcpip.ErrNoRoute: ErrNoRoute, - tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound: ErrAlreadyBound, - tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected: ErrAlreadyConnected, - tcpip.ErrNoPortAvailable: ErrNoPortAvailable, - tcpip.ErrPortInUse: ErrPortInUse, - tcpip.ErrBadLocalAddress: ErrBadLocalAddress, - tcpip.ErrClosedForSend: ErrClosedForSend, - tcpip.ErrClosedForReceive: ErrClosedForReceive, - tcpip.ErrWouldBlock: ErrWouldBlock, - tcpip.ErrConnectionRefused: ErrConnectionRefused, - tcpip.ErrTimeout: ErrTimeout, - tcpip.ErrAborted: ErrAborted, - tcpip.ErrConnectStarted: ErrConnectStarted, - tcpip.ErrDestinationRequired: ErrDestinationRequired, - tcpip.ErrNotSupported: ErrNotSupported, - tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported, - tcpip.ErrNotConnected: ErrNotConnected, - tcpip.ErrConnectionReset: ErrConnectionReset, - tcpip.ErrConnectionAborted: ErrConnectionAborted, - tcpip.ErrNoSuchFile: ErrNoSuchFile, - tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress: ErrHostDown, - tcpip.ErrBadAddress: ErrBadAddress, - tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, - tcpip.ErrMessageTooLong: ErrMessageTooLong, - tcpip.ErrNoBufferSpace: ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled, - tcpip.ErrNotPermitted: ErrNotPermittedNet, - tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported, +var netstackErrorTranslations map[string]*Error + +func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) { + key := tcpipErr.String() + if _, ok := netstackErrorTranslations[key]; ok { + panic(fmt.Sprintf("duplicate error key: %s", key)) + } + netstackErrorTranslations[key] = netstackErr +} + +func init() { + netstackErrorTranslations = make(map[string]*Error) + addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol) + addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID) + addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice) + addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption) + addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID) + addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress) + addErrMapping(tcpip.ErrNoRoute, ErrNoRoute) + addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint) + addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound) + addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState) + addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting) + addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected) + addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable) + addErrMapping(tcpip.ErrPortInUse, ErrPortInUse) + addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress) + addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend) + addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive) + addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock) + addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused) + addErrMapping(tcpip.ErrTimeout, ErrTimeout) + addErrMapping(tcpip.ErrAborted, ErrAborted) + addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted) + addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired) + addErrMapping(tcpip.ErrNotSupported, ErrNotSupported) + addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported) + addErrMapping(tcpip.ErrNotConnected, ErrNotConnected) + addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset) + addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted) + addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile) + addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue) + addErrMapping(tcpip.ErrNoLinkAddress, ErrHostDown) + addErrMapping(tcpip.ErrBadAddress, ErrBadAddress) + addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable) + addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong) + addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace) + addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled) + addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet) + addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported) } // TranslateNetstackError converts an error from the tcpip package to a sentry @@ -95,7 +108,7 @@ func TranslateNetstackError(err *tcpip.Error) *Error { if err == nil { return nil } - se, ok := netstackErrorTranslations[err] + se, ok := netstackErrorTranslations[err.String()] if !ok { panic("Unknown error: " + err.String()) } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 454e07662..27f96a3ac 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ + "socketops.go", "tcpip.go", "time_unsafe.go", "timer.go", diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 12b061def..b196324c7 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -97,6 +97,9 @@ type testConnection struct { func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + return nil, err + } entry, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&entry, waiter.EventOut) @@ -145,7 +148,9 @@ func TestCloseReader(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } // Give c.Read() a chance to block before closing the connection. @@ -416,7 +421,9 @@ func TestDeadlineChange(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } c.SetDeadline(time.Now().Add(time.Minute)) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 6f81b0164..91971b687 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "reflect" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -116,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker { v = ip.TTL() case header.IPv6: v = ip.HopLimit() + case *ipv6HeaderWithExtHdr: + v = ip.HopLimit() + default: + t.Fatalf("unrecognized header type %T for TTL evaluation", ip) } if v != ttl { t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) @@ -197,7 +202,7 @@ func IPPayload(payload []byte) NetworkChecker { } // IPv4Options returns a checker that checks the options in an IPv4 packet. -func IPv4Options(want []byte) NetworkChecker { +func IPv4Options(want header.IPv4Options) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -216,6 +221,42 @@ func IPv4Options(want []byte) NetworkChecker { } } +// IPv4RouterAlert returns a checker that checks that the RouterAlert option is +// set in an IPv4 packet. +func IPv4RouterAlert() NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + ip, ok := h[0].(header.IPv4) + if !ok { + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) + } + iterator := ip.Options().MakeIterator() + for { + opt, done, err := iterator.Next() + if err != nil { + t.Fatalf("error acquiring next IPv4 option %s", err) + } + if done { + break + } + if opt.Type() != header.IPv4OptionRouterAlertType { + continue + } + want := [header.IPv4OptionRouterAlertLength]byte{ + byte(header.IPv4OptionRouterAlertType), + header.IPv4OptionRouterAlertLength, + header.IPv4OptionRouterAlertValue, + header.IPv4OptionRouterAlertValue, + } + if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { + t.Errorf("router alert option mismatch (-want +got):\n%s", diff) + } + return + } + t.Errorf("failed to find router alert option in %v", ip.Options()) + } +} + // FragmentOffset creates a checker that checks the FragmentOffset field. func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -284,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress +// field in ControlMessages. +func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasOriginalDstAddress { + t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) + } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { + t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -859,6 +913,21 @@ func ICMPv4Seq(want uint16) TransportChecker { } } +// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. +func ICMPv4Pointer(want uint8) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Pointer(); got != want { + t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) + } + } +} + // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. // This assumes that the payload exactly makes up the rest of the slice. func ICMPv4Checksum() TransportChecker { @@ -889,6 +958,12 @@ func ICMPv4Payload(want []byte) TransportChecker { t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } payload := icmpv4.Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(payload) == 0 { + return + } + if diff := cmp.Diff(want, payload); diff != "" { t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } @@ -953,6 +1028,112 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { } } +// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific +// field. +func ICMPv6TypeSpecific(want uint32) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + if got := icmpv6.TypeSpecific(); got != want { + t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) + } + } +} + +// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. +func ICMPv6Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + payload := icmpv6.Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(payload) == 0 { + return + } + + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) + } + } +} + +// MLD creates a checker that checks that the packet contains a valid MLD +// message for type of mldType, with potentially additional checks specified by +// checkers. +// +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// MLD message as far as the size of the message (minSize) is concerned. The +// values within the message are up to checkers to validate. +func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + // Check normal ICMPv6 first. + ICMPv6( + ICMPv6Type(msgType), + ICMPv6Code(0))(t, h) + + last := h[len(h)-1] + + icmp := header.ICMPv6(last.Payload()) + if got := len(icmp.MessageBody()); got < minSize { + t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) + } + + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay +// field of a MLD message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid MLD message as far as the size is concerned. +func MLDMaxRespDelay(want time.Duration) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.MLD(icmp.MessageBody()) + + if got := ns.MaximumResponseDelay(); got != want { + t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) + } + } +} + +// MLDMulticastAddress creates a checker that checks the Multicast Address +// field of a MLD message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid MLD message as far as the size is concerned. +func MLDMulticastAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.MLD(icmp.MessageBody()) + + if got := ns.MulticastAddress(); got != want { + t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want) + } + } +} + // NDP creates a checker that checks that the packet contains a valid NDP // message for type of ty, with potentially additional checks specified by // checkers. @@ -972,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N last := h[len(h)-1] icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.NDPPayload()); got < minSize { + if got := len(icmp.MessageBody()); got < minSize { t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) } @@ -1006,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) if got := ns.TargetAddress(); got != want { t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) @@ -1035,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) if got := na.TargetAddress(); got != want { t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) @@ -1053,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) if got := na.SolicitedFlag(); got != want { t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) @@ -1124,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) ndpOptions(t, na.Options(), opts) } } @@ -1139,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ndpOptions(t, ns.Options(), opts) } } @@ -1164,7 +1345,261 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - rs := header.NDPRouterSolicit(icmp.NDPPayload()) + rs := header.NDPRouterSolicit(icmp.MessageBody()) ndpOptions(t, rs.Options(), opts) } } + +// IGMP checks the validity and properties of the given IGMP packet. It is +// expected to be used in conjunction with other IGMP transport checkers for +// specific properties. +func IGMP(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) + } + + igmp := header.IGMP(last.Payload()) + for _, f := range checkers { + f(t, igmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// IGMPType creates a checker that checks the IGMP Type field. +func IGMPType(want header.IGMPType) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.Type(); got != want { + t.Errorf("got igmp.Type() = %d, want = %d", got, want) + } + } +} + +// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. +func IGMPMaxRespTime(want time.Duration) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.MaxRespTime(); got != want { + t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) + } + } +} + +// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. +func IGMPGroupAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.GroupAddress(); got != want { + t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) + } + } +} + +// IPv6ExtHdrChecker is a function to check an extension header. +type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) + +// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. +func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + + ipv6 := header.IPv6(b) + if !ipv6.IsValid(len(b)) { + t.Error("not a valid IPv6 packet") + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var rawPayloadHeader header.IPv6RawPayloadHeader + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) + return + } + r, ok := h.(header.IPv6RawPayloadHeader) + if ok { + rawPayloadHeader = r + break + } + } + + networkHeader := ipv6HeaderWithExtHdr{ + IPv6: ipv6, + transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), + payload: rawPayloadHeader.Buf.ToView(), + } + + for _, checker := range checkers { + checker(t, []header.Network{&networkHeader}) + } +} + +// IPv6ExtHdr checks for the presence of extension headers. +// +// All the extension headers in headers will be checked exhaustively in the +// order provided. +func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) + if !ok { + t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), + buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), + ) + + for _, check := range headers { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) + return + } + check(t, h) + } + // Validate we consumed all headers. + // + // The next one over should be a raw payload and then iterator should + // terminate. + wantDone := false + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done != wantDone { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) + return + } + if done { + break + } + if _, ok := h.(header.IPv6RawPayloadHeader); !ok { + t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) + continue + } + wantDone = true + } + } +} + +var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) + +// ipv6HeaderWithExtHdr provides a header.Network implementation that takes +// extension headers into consideration, which is not the case with vanilla +// header.IPv6. +type ipv6HeaderWithExtHdr struct { + header.IPv6 + transport tcpip.TransportProtocolNumber + payload []byte +} + +// TransportProtocol implements header.Network. +func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { + return h.transport +} + +// Payload implements header.Network. +func (h *ipv6HeaderWithExtHdr) Payload() []byte { + return h.payload +} + +// IPv6ExtHdrOptionChecker is a function to check an extension header option. +type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) + +// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop +// extension header and validates the containing options with checkers. +// +// checkers must exhaustively contain all the expected options. +func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { + return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { + t.Helper() + + hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) + if !ok { + t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) + return + } + optionsIterator := hbh.Iter() + for _, f := range checkers { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + f(t, opt) + } + // Validate all options were consumed. + for { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if !done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + if done { + break + } + } + } +} + +// IPv6RouterAlert validates that an extension header option is the RouterAlert +// option and matches on its value. +func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + routerAlert, ok := opt.(*header.IPv6RouterAlertOption) + if !ok { + t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) + return + } + if routerAlert.Value != want { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) + } + } +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index d87797617..0bdc12d53 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -11,11 +11,13 @@ go_library( "gue.go", "icmpv4.go", "icmpv6.go", + "igmp.go", "interfaces.go", "ipv4.go", "ipv6.go", "ipv6_extension_headers.go", "ipv6_fragment.go", + "mld.go", "ndp_neighbor_advert.go", "ndp_neighbor_solicit.go", "ndp_options.go", @@ -39,6 +41,8 @@ go_test( size = "small", srcs = [ "checksum_test.go", + "igmp_test.go", + "ipv4_test.go", "ipv6_test.go", "ipversion_test.go", "tcp_test.go", @@ -58,6 +62,7 @@ go_test( srcs = [ "eth_test.go", "ipv6_extension_headers_test.go", + "mld_test.go", "ndp_test.go", ], library = ":header", diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 504408878..2f13dea6a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -99,7 +99,8 @@ const ( // ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. const ( - ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4ReassemblyTimeout ICMPv4Code = 1 ) // ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. @@ -126,6 +127,12 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } +// Pointer returns the pointer field in a Parameter Problem packet. +func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] } + +// SetPointer sets the pointer field in a Parameter Problem packet. +func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } + // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 4303fc5d5..2eef64b4d 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -115,6 +115,12 @@ const ( ICMPv6NeighborSolicit ICMPv6Type = 135 ICMPv6NeighborAdvert ICMPv6Type = 136 ICMPv6RedirectMsg ICMPv6Type = 137 + + // Multicast Listener Discovery (MLD) messages, see RFC 2710. + + ICMPv6MulticastListenerQuery ICMPv6Type = 130 + ICMPv6MulticastListenerReport ICMPv6Type = 131 + ICMPv6MulticastListenerDone ICMPv6Type = 132 ) // IsErrorType returns true if the receiver is an ICMP error type. @@ -245,10 +251,9 @@ func (b ICMPv6) SetSequence(sequence uint16) { binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) } -// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6 -// packet's message body as defined by RFC 4443 section 2.1; the portion of the -// ICMPv6 buffer after the first ICMPv6HeaderSize bytes. -func (b ICMPv6) NDPPayload() []byte { +// MessageBody returns the message body as defined by RFC 4443 section 2.1; the +// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes. +func (b ICMPv6) MessageBody() []byte { return b[ICMPv6HeaderSize:] } diff --git a/pkg/tcpip/header/igmp.go b/pkg/tcpip/header/igmp.go new file mode 100644 index 000000000..5c5be1b9d --- /dev/null +++ b/pkg/tcpip/header/igmp.go @@ -0,0 +1,181 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// IGMP represents an IGMP header stored in a byte array. +type IGMP []byte + +// IGMP implements `Transport`. +var _ Transport = (*IGMP)(nil) + +const ( + // IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes, + // as per RFC 2236, Section 2, Page 2. + IGMPMinimumSize = 8 + + // IGMPQueryMinimumSize is the minimum size of a valid Membership Query + // Message in bytes, as per RFC 2236, Section 2, Page 2. + IGMPQueryMinimumSize = 8 + + // IGMPReportMinimumSize is the minimum size of a valid Report Message in + // bytes, as per RFC 2236, Section 2, Page 2. + IGMPReportMinimumSize = 8 + + // IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message + // in bytes, as per RFC 2236, Section 2, Page 2. + IGMPLeaveMessageMinimumSize = 8 + + // IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page + // 3. + IGMPTTL = 1 + + // igmpTypeOffset defines the offset of the type field in an IGMP message. + igmpTypeOffset = 0 + + // igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an + // IGMP message. + igmpMaxRespTimeOffset = 1 + + // igmpChecksumOffset defines the offset of the checksum field in an IGMP + // message. + igmpChecksumOffset = 2 + + // igmpGroupAddressOffset defines the offset of the Group Address field in an + // IGMP message. + igmpGroupAddressOffset = 4 + + // IGMPProtocolNumber is IGMP's transport protocol number. + IGMPProtocolNumber tcpip.TransportProtocolNumber = 2 +) + +// IGMPType is the IGMP type field as per RFC 2236. +type IGMPType byte + +// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2. +// Descriptions below come from there. +const ( + // IGMPMembershipQuery indicates that the message type is Membership Query. + // "There are two sub-types of Membership Query messages: + // - General Query, used to learn which groups have members on an + // attached network. + // - Group-Specific Query, used to learn if a particular group + // has any members on an attached network. + // These two messages are differentiated by the Group Address, as + // described in section 1.4 ." + IGMPMembershipQuery IGMPType = 0x11 + // IGMPv1MembershipReport indicates that the message is a Membership Report + // generated by a host using the IGMPv1 protocol: "an additional type of + // message, for backwards-compatibility with IGMPv1" + IGMPv1MembershipReport IGMPType = 0x12 + // IGMPv2MembershipReport indicates that the Message type is a Membership + // Report generated by a host using the IGMPv2 protocol. + IGMPv2MembershipReport IGMPType = 0x16 + // IGMPLeaveGroup indicates that the message type is a Leave Group + // notification message. + IGMPLeaveGroup IGMPType = 0x17 +) + +// Type is the IGMP type field. +func (b IGMP) Type() IGMPType { return IGMPType(b[igmpTypeOffset]) } + +// SetType sets the IGMP type field. +func (b IGMP) SetType(t IGMPType) { b[igmpTypeOffset] = byte(t) } + +// MaxRespTime gets the MaxRespTimeField. This is meaningful only in Membership +// Query messages, in other cases it is set to 0 by the sender and ignored by +// the receiver. +func (b IGMP) MaxRespTime() time.Duration { + // As per RFC 2236 section 2.2, + // + // The Max Response Time field is meaningful only in Membership Query + // messages, and specifies the maximum allowed time before sending a + // responding report in units of 1/10 second. In all other messages, it + // is set to zero by the sender and ignored by receivers. + return DecisecondToDuration(b[igmpMaxRespTimeOffset]) +} + +// SetMaxRespTime sets the MaxRespTimeField. +func (b IGMP) SetMaxRespTime(m byte) { b[igmpMaxRespTimeOffset] = m } + +// Checksum is the IGMP checksum field. +func (b IGMP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[igmpChecksumOffset:]) +} + +// SetChecksum sets the IGMP checksum field. +func (b IGMP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[igmpChecksumOffset:], checksum) +} + +// GroupAddress gets the Group Address field. +func (b IGMP) GroupAddress() tcpip.Address { + return tcpip.Address(b[igmpGroupAddressOffset:][:IPv4AddressSize]) +} + +// SetGroupAddress sets the Group Address field. +func (b IGMP) SetGroupAddress(address tcpip.Address) { + if n := copy(b[igmpGroupAddressOffset:], address); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, IPv4AddressSize)) + } +} + +// SourcePort implements Transport.SourcePort. +func (IGMP) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (IGMP) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (IGMP) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (IGMP) SetDestinationPort(uint16) { +} + +// Payload implements Transport.Payload. +func (IGMP) Payload() []byte { + return nil +} + +// IGMPCalculateChecksum calculates the IGMP checksum over the provided IGMP +// header. +func IGMPCalculateChecksum(h IGMP) uint16 { + // The header contains a checksum itself, set it aside to avoid checksumming + // the checksum and replace it afterwards. + existingXsum := h.Checksum() + h.SetChecksum(0) + xsum := ^Checksum(h, 0) + h.SetChecksum(existingXsum) + return xsum +} + +// DecisecondToDuration converts a value representing deci-seconds to a +// time.Duration. +func DecisecondToDuration(ds uint8) time.Duration { + return time.Duration(ds) * time.Second / 10 +} diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go new file mode 100644 index 000000000..b6126d29a --- /dev/null +++ b/pkg/tcpip/header/igmp_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// TestIGMPHeader tests the functions within header.igmp +func TestIGMPHeader(t *testing.T) { + const maxRespTimeTenthSec = 0xF0 + b := []byte{ + 0x11, // IGMP Type, Membership Query + maxRespTimeTenthSec, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want { + t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) + } + + if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) + } + + igmpType := header.IGMPv2MembershipReport + igmpHeader.SetType(igmpType) + if got := igmpHeader.Type(); got != igmpType { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType) + } + if got := header.IGMPType(b[0]); got != igmpType { + t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType) + } + + respTime := byte(0x02) + igmpHeader.SetMaxRespTime(respTime) + if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want { + t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) + } + + checksum := uint16(0x0102) + igmpHeader.SetChecksum(checksum) + if got := igmpHeader.Checksum(); got != checksum { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) + } + + groupAddress := tcpip.Address("\x04\x03\x02\x01") + igmpHeader.SetGroupAddress(groupAddress) + if got := igmpHeader.GroupAddress(); got != groupAddress { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) + } +} + +// TestIGMPChecksum ensures that the checksum calculator produces the expected +// checksum. +func TestIGMPChecksum(t *testing.T) { + b := []byte{ + 0x11, // IGMP Type, Membership Query + 0xF0, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + // Calculate the initial checksum after setting the checksum temporarily to 0 + // to avoid checksumming the checksum. + initialChecksum := igmpHeader.Checksum() + igmpHeader.SetChecksum(0) + checksum := ^header.Checksum(b, 0) + igmpHeader.SetChecksum(initialChecksum) + + if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum { + t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum) + } +} + +func TestDecisecondToDuration(t *testing.T) { + const valueInDeciseconds = 5 + if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want { + t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want) + } +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 4c6e4be64..e6103f4bc 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,7 +39,6 @@ import ( // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Options | Padding | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// const ( versIHL = 0 tos = 1 @@ -56,12 +56,9 @@ const ( ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the -// fields of a packet that needs to be encoded. +// fields of a packet that needs to be encoded. The IHL field is not here as +// it is totally defined by the size of the options. type IPv4Fields struct { - // IHL is the "internet header length" field of an IPv4 packet. The value - // is in bytes. - IHL uint8 - // TOS is the "type of service" field of an IPv4 packet. TOS uint8 @@ -91,9 +88,22 @@ type IPv4Fields struct { // DstAddr is the "destination ip address" of an IPv4 packet. DstAddr tcpip.Address + + // Options must be 40 bytes or less as they must fit along with the + // rest of the IPv4 header into the maximum size describable in the + // IHL field. RFC 791 section 3.1 says: + // IHL: 4 bits + // + // Internet Header Length is the length of the internet header in 32 + // bit words, and thus points to the beginning of the data. Note that + // the minimum value for a correct header is 5. + // + // That leaves ten 32 bit (4 byte) fields for options. An attempt to encode + // more will fail. + Options IPv4OptionsSerializer } -// IPv4 represents an ipv4 header stored in a byte array. +// IPv4 is an IPv4 header. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. // Always call IsValid() to validate an instance of IPv4 before using other @@ -106,16 +116,19 @@ const ( IPv4MinimumSize = 20 // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given - // that there are only 4 bits to represents the header length in 32-bit - // units, the header cannot exceed 15*4 = 60 bytes. + // that there are only 4 bits (max 0xF (15)) to represent the header length + // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // IPv4MaximumOptionsSize is the largest size the IPv4 options can be. + IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. // // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 // header size). But RFC 791 section 3.2 discusses the design of the IPv4 // fragment "allows 2**13 = 8192 fragments of 8 octets each for a total of - // 65,536 octets. Note that this is consistent with the the datagram total + // 65,536 octets. Note that this is consistent with the datagram total // length field (of course, the header is counted in the total length and not // in the fragments)." IPv4MaximumPayloadSize = 65536 @@ -130,7 +143,7 @@ const ( // IPv4ProtocolNumber is IPv4's network protocol number. IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800 - // IPv4Version is the version of the ipv4 protocol. + // IPv4Version is the version of the IPv4 protocol. IPv4Version = 4 // IPv4AllSystems is the all systems IPv4 multicast address as per @@ -144,10 +157,20 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + // IPv4AllRoutersGroup is a multicast address for all routers. + IPv4AllRoutersGroup tcpip.Address = "\xe0\x00\x00\x02" + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP // packet that every IPv4 capable host must be able to // process/reassemble. IPv4MinimumProcessableDatagramSize = 576 + + // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791, + // section 3.2: + // Every internet module must be able to forward a datagram of 68 octets + // without further fragmentation. This is because an internet header may be + // up to 60 octets, and the minimum fragment is 8 octets. + IPv4MinimumMTU = 68 ) // Flags that may be set in an IPv4 packet. @@ -191,14 +214,13 @@ func IPVersion(b []byte) int { // Internet Header Length is the length of the internet header in 32 // bit words, and thus points to the beginning of the data. Note that // the minimum value for a correct header is 5. -// const ( ipVersionShift = 4 ipIHLMask = 0x0f IPv4IHLStride = 4 ) -// HeaderLength returns the value of the "header length" field of the ipv4 +// HeaderLength returns the value of the "header length" field of the IPv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { return (b[versIHL] & ipIHLMask) * IPv4IHLStride @@ -212,17 +234,17 @@ func (b IPv4) SetHeaderLength(hdrLen uint8) { b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } -// ID returns the value of the identifier field of the ipv4 header. +// ID returns the value of the identifier field of the IPv4 header. func (b IPv4) ID() uint16 { return binary.BigEndian.Uint16(b[id:]) } -// Protocol returns the value of the protocol field of the ipv4 header. +// Protocol returns the value of the protocol field of the IPv4 header. func (b IPv4) Protocol() uint8 { return b[protocol] } -// Flags returns the "flags" field of the ipv4 header. +// Flags returns the "flags" field of the IPv4 header. func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } @@ -232,41 +254,52 @@ func (b IPv4) More() bool { return b.Flags()&IPv4FlagMoreFragments != 0 } -// TTL returns the "TTL" field of the ipv4 header. +// TTL returns the "TTL" field of the IPv4 header. func (b IPv4) TTL() uint8 { return b[ttl] } -// FragmentOffset returns the "fragment offset" field of the ipv4 header. +// FragmentOffset returns the "fragment offset" field of the IPv4 header. func (b IPv4) FragmentOffset() uint16 { return binary.BigEndian.Uint16(b[flagsFO:]) << 3 } -// TotalLength returns the "total length" field of the ipv4 header. +// TotalLength returns the "total length" field of the IPv4 header. func (b IPv4) TotalLength() uint16 { return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } -// Checksum returns the checksum field of the ipv4 header. +// Checksum returns the checksum field of the IPv4 header. func (b IPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[checksum:]) } -// SourceAddress returns the "source address" field of the ipv4 header. +// SourceAddress returns the "source address" field of the IPv4 header. func (b IPv4) SourceAddress() tcpip.Address { return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize]) } -// DestinationAddress returns the "destination address" field of the ipv4 +// DestinationAddress returns the "destination address" field of the IPv4 // header. func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } -// Options returns a a buffer holding the options. -func (b IPv4) Options() []byte { +// padIPv4OptionsLength returns the total length for IPv4 options of length l +// after applying padding according to RFC 791: +// The internet header padding is used to ensure that the internet +// header ends on a 32 bit boundary. +func padIPv4OptionsLength(length uint8) uint8 { + return (length + IPv4IHLStride - 1) & ^uint8(IPv4IHLStride-1) +} + +// IPv4Options is a buffer that holds all the raw IP options. +type IPv4Options []byte + +// Options returns a buffer holding the options. +func (b IPv4) Options() IPv4Options { hdrLen := b.HeaderLength() - return b[options:hdrLen:hdrLen] + return IPv4Options(b[options:hdrLen:hdrLen]) } // TransportProtocol implements Network.TransportProtocol. @@ -279,17 +312,17 @@ func (b IPv4) Payload() []byte { return b[b.HeaderLength():][:b.PayloadLength()] } -// PayloadLength returns the length of the payload portion of the ipv4 packet. +// PayloadLength returns the length of the payload portion of the IPv4 packet. func (b IPv4) PayloadLength() uint16 { return b.TotalLength() - uint16(b.HeaderLength()) } -// TOS returns the "type of service" field of the ipv4 header. +// TOS returns the "type of service" field of the IPv4 header. func (b IPv4) TOS() (uint8, uint32) { return b[tos], 0 } -// SetTOS sets the "type of service" field of the ipv4 header. +// SetTOS sets the "type of service" field of the IPv4 header. func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } @@ -299,18 +332,18 @@ func (b IPv4) SetTTL(v byte) { b[ttl] = v } -// SetTotalLength sets the "total length" field of the ipv4 header. +// SetTotalLength sets the "total length" field of the IPv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } -// SetChecksum sets the checksum field of the ipv4 header. +// SetChecksum sets the checksum field of the IPv4 header. func (b IPv4) SetChecksum(v uint16) { binary.BigEndian.PutUint16(b[checksum:], v) } // SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the -// ipv4 header. +// IPv4 header. func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { v := (uint16(flags) << 13) | (offset >> 3) binary.BigEndian.PutUint16(b[flagsFO:], v) @@ -321,25 +354,36 @@ func (b IPv4) SetID(v uint16) { binary.BigEndian.PutUint16(b[id:], v) } -// SetSourceAddress sets the "source address" field of the ipv4 header. +// SetSourceAddress sets the "source address" field of the IPv4 header. func (b IPv4) SetSourceAddress(addr tcpip.Address) { copy(b[srcAddr:srcAddr+IPv4AddressSize], addr) } -// SetDestinationAddress sets the "destination address" field of the ipv4 +// SetDestinationAddress sets the "destination address" field of the IPv4 // header. func (b IPv4) SetDestinationAddress(addr tcpip.Address) { copy(b[dstAddr:dstAddr+IPv4AddressSize], addr) } -// CalculateChecksum calculates the checksum of the ipv4 header. +// CalculateChecksum calculates the checksum of the IPv4 header. func (b IPv4) CalculateChecksum() uint16 { return Checksum(b[:b.HeaderLength()], 0) } -// Encode encodes all the fields of the ipv4 header. +// Encode encodes all the fields of the IPv4 header. func (b IPv4) Encode(i *IPv4Fields) { - b.SetHeaderLength(i.IHL) + // The size of the options defines the size of the whole header and thus the + // IHL field. Options are rare and this is a heavily used function so it is + // worth a bit of optimisation here to keep the serializer out of the fast + // path. + hdrLen := uint8(IPv4MinimumSize) + if len(i.Options) != 0 { + hdrLen += i.Options.Serialize(b[options:]) + } + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("%d is larger than maximum IPv4 header size of %d", hdrLen, IPv4MaximumHeaderSize)) + } + b.SetHeaderLength(hdrLen) b[tos] = i.TOS b.SetTotalLength(i.TotalLength) binary.BigEndian.PutUint16(b[id:], i.ID) @@ -351,7 +395,7 @@ func (b IPv4) Encode(i *IPv4Fields) { copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) } -// EncodePartial updates the total length and checksum fields of ipv4 header, +// EncodePartial updates the total length and checksum fields of IPv4 header, // taking in the partial checksum, which is the checksum of the header without // the total length and checksum fields. It is useful in cases when similar // packets are produced. @@ -398,3 +442,587 @@ func IsV4LoopbackAddress(addr tcpip.Address) bool { } return addr[0] == 0x7f } + +// ========================= Options ========================== + +// An IPv4OptionType can hold the valuse for the Type in an IPv4 option. +type IPv4OptionType byte + +// These constants are needed to identify individual options in the option list. +// While RFC 791 (page 31) says "Every internet module must be able to act on +// every option." This has not generally been adhered to and some options have +// very low rates of support. We do not support options other than those shown +// below. + +const ( + // IPv4OptionListEndType is the option type for the End Of Option List + // option. Anything following is ignored. + IPv4OptionListEndType IPv4OptionType = 0 + + // IPv4OptionNOPType is the No-Operation option. May appear between other + // options and may appear multiple times. + IPv4OptionNOPType IPv4OptionType = 1 + + // IPv4OptionRouterAlertType is the option type for the Router Alert option, + // defined in RFC 2113 Section 2.1. + IPv4OptionRouterAlertType IPv4OptionType = 20 | 0x80 + + // IPv4OptionRecordRouteType is used by each router on the path of the packet + // to record its path. It is carried over to an Echo Reply. + IPv4OptionRecordRouteType IPv4OptionType = 7 + + // IPv4OptionTimestampType is the option type for the Timestamp option. + IPv4OptionTimestampType IPv4OptionType = 68 + + // ipv4OptionTypeOffset is the offset in an option of its type field. + ipv4OptionTypeOffset = 0 + + // IPv4OptionLengthOffset is the offset in an option of its length field. + IPv4OptionLengthOffset = 1 +) + +// Potential errors when parsing generic IP options. +var ( + ErrIPv4OptZeroLength = errors.New("zero length IP option") + ErrIPv4OptDuplicate = errors.New("duplicate IP option") + ErrIPv4OptInvalid = errors.New("invalid IP option") + ErrIPv4OptMalformed = errors.New("malformed IP option") + ErrIPv4OptionTruncated = errors.New("truncated IP option") + ErrIPv4OptionAddress = errors.New("bad IP option address") +) + +// IPv4Option is an interface representing various option types. +type IPv4Option interface { + // Type returns the type identifier of the option. + Type() IPv4OptionType + + // Size returns the size of the option in bytes. + Size() uint8 + + // Contents returns a slice holding the contents of the option. + Contents() []byte +} + +var _ IPv4Option = (*IPv4OptionGeneric)(nil) + +// IPv4OptionGeneric is an IPv4 Option of unknown type. +type IPv4OptionGeneric []byte + +// Type implements IPv4Option. +func (o *IPv4OptionGeneric) Type() IPv4OptionType { + return IPv4OptionType((*o)[ipv4OptionTypeOffset]) +} + +// Size implements IPv4Option. +func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } + +// Contents implements IPv4Option. +func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) } + +// IPv4OptionIterator is an iterator pointing to a specific IP option +// at any point of time. It also holds information as to a new options buffer +// that we are building up to hand back to the caller. +type IPv4OptionIterator struct { + options IPv4Options + // ErrCursor is where we are while parsing options. It is exported as any + // resulting ICMP packet is supposed to have a pointer to the byte within + // the IP packet where the error was detected. + ErrCursor uint8 + nextErrCursor uint8 + newOptions [IPv4MaximumOptionsSize]byte + writePoint int +} + +// MakeIterator sets up and returns an iterator of options. It also sets up the +// building of a new option set. +func (o IPv4Options) MakeIterator() IPv4OptionIterator { + return IPv4OptionIterator{ + options: o, + nextErrCursor: IPv4MinimumSize, + } +} + +// RemainingBuffer returns the remaining (unused) part of the new option buffer, +// into which a new option may be written. +func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { + return IPv4Options(i.newOptions[i.writePoint:]) +} + +// ConsumeBuffer marks a portion of the new buffer as used. +func (i *IPv4OptionIterator) ConsumeBuffer(size int) { + i.writePoint += size +} + +// PushNOPOrEnd puts one of the single byte options onto the new options. +// Only values 0 or 1 (ListEnd or NOP) are valid input. +func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) { + if val > IPv4OptionNOPType { + panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val)) + } + i.newOptions[i.writePoint] = byte(val) + i.writePoint++ +} + +// Finalize returns the completed replacement options buffer padded +// as needed. +func (i *IPv4OptionIterator) Finalize() IPv4Options { + // RFC 791 page 31 says: + // The options might not end on a 32-bit boundary. The internet header + // must be filled out with octets of zeros. The first of these would + // be interpreted as the end-of-options option, and the remainder as + // internet header padding. + // Since the buffer is already zero filled we just need to step the write + // pointer up to the next multiple of 4. + options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3]) + // Poison the write pointer. + i.writePoint = len(i.newOptions) + return options +} + +// Next returns the next IP option in the buffer/list of IP options. +// It returns +// - A slice of bytes holding the next option or nil if there is error. +// - A boolean which is true if parsing of all the options is complete. +// - An error which is non-nil if an error condition was encountered. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { + // The opts slice gets shorter as we process the options. When we have no + // bytes left we are done. + if len(i.options) == 0 { + return nil, true, nil + } + + i.ErrCursor = i.nextErrCursor + + optType := IPv4OptionType(i.options[ipv4OptionTypeOffset]) + + if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType { + optionBody := i.options[:1] + i.options = i.options[1:] + i.nextErrCursor = i.ErrCursor + 1 + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil + } + + // There are no more single byte options defined. All the rest have a length + // field so we need to sanity check it. + if len(i.options) == 1 { + return nil, true, ErrIPv4OptMalformed + } + + optLen := i.options[IPv4OptionLengthOffset] + + if optLen == 0 { + i.ErrCursor++ + return nil, true, ErrIPv4OptZeroLength + } + + if optLen == 1 { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + + if optLen > uint8(len(i.options)) { + i.ErrCursor++ + return nil, true, ErrIPv4OptionTruncated + } + + optionBody := i.options[:optLen] + i.nextErrCursor = i.ErrCursor + optLen + i.options = i.options[optLen:] + + // Check the length of some option types that we know. + switch optType { + case IPv4OptionTimestampType: + if optLen < IPv4OptionTimestampHdrLength { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + retval := IPv4OptionTimestamp(optionBody) + return &retval, false, nil + + case IPv4OptionRecordRouteType: + if optLen < IPv4OptionRecordRouteHdrLength { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + retval := IPv4OptionRecordRoute(optionBody) + return &retval, false, nil + } + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil +} + +// +// IP Timestamp option - RFC 791 page 22. +// +--------+--------+--------+--------+ +// |01000100| length | pointer|oflw|flg| +// +--------+--------+--------+--------+ +// | internet address | +// +--------+--------+--------+--------+ +// | timestamp | +// +--------+--------+--------+--------+ +// | ... | +// +// Type = 68 +// +// The Option Length is the number of octets in the option counting +// the type, length, pointer, and overflow/flag octets (maximum +// length 40). +// +// The Pointer is the number of octets from the beginning of this +// option to the end of timestamps plus one (i.e., it points to the +// octet beginning the space for next timestamp). The smallest +// legal value is 5. The timestamp area is full when the pointer +// is greater than the length. +// +// The Overflow (oflw) [4 bits] is the number of IP modules that +// cannot register timestamps due to lack of space. +// +// The Flag (flg) [4 bits] values are +// +// 0 -- time stamps only, stored in consecutive 32-bit words, +// +// 1 -- each timestamp is preceded with internet address of the +// registering entity, +// +// 3 -- the internet address fields are prespecified. An IP +// module only registers its timestamp if it matches its own +// address with the next specified internet address. +// +// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UTC. +// +// The Timestamp is a right-justified, 32-bit timestamp in +// milliseconds since midnight UT. If the time is not available in +// milliseconds or cannot be provided with respect to midnight UT +// then any time may be inserted as a timestamp provided the high +// order bit of the timestamp field is set to one to indicate the +// use of a non-standard value. + +// IPv4OptTSFlags sefines the values expected in the Timestamp +// option Flags field. +type IPv4OptTSFlags uint8 + +// +// Timestamp option specific related constants. +const ( + // IPv4OptionTimestampHdrLength is the length of the timestamp option header. + IPv4OptionTimestampHdrLength = 4 + + // IPv4OptionTimestampSize is the size of an IP timestamp. + IPv4OptionTimestampSize = 4 + + // IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address. + IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize + + // IPv4OptionTimestampMaxSize is limited by space for options + IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize + + // IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp + // is present. + IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0 + + // IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and + // IP are present. + IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1 + + // IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that + // predefined IP is present. + IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3 +) + +// ipv4TimestampTime provides the current time as specified in RFC 791. +func ipv4TimestampTime(clock tcpip.Clock) uint32 { + const millisecondsPerDay = 24 * 3600 * 1000 + const nanoPerMilli = 1000000 + return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay) +} + +// IP Timestamp option fields. +const ( + // IPv4OptTSPointerOffset is the offset of the Timestamp pointer field. + IPv4OptTSPointerOffset = 2 + + // IPv4OptTSPointerOffset is the offset of the combined Flag and Overflow + // fields, (each being 4 bits). + IPv4OptTSOFLWAndFLGOffset = 3 + // These constants define the sub byte fields of the Flag and OverFlow field. + ipv4OptionTimestampOverflowshift = 4 + ipv4OptionTimestampFlagsMask byte = 0x0f +) + +var _ IPv4Option = (*IPv4OptionTimestamp)(nil) + +// IPv4OptionTimestamp is a Timestamp option from RFC 791. +type IPv4OptionTimestamp []byte + +// Type implements IPv4Option.Type(). +func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType } + +// Size implements IPv4Option. +func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) } + +// Contents implements IPv4Option. +func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) } + +// Pointer returns the pointer field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Pointer() uint8 { + return (*ts)[IPv4OptTSPointerOffset] +} + +// Flags returns the flags field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags { + return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask) +} + +// Overflow returns the Overflow field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Overflow() uint8 { + return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift +} + +// IncOverflow increments the Overflow field in the IP Timestamp option. It +// returns the incremented value. If the return value is 0 then the field +// overflowed. +func (ts *IPv4OptionTimestamp) IncOverflow() uint8 { + (*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift + return ts.Overflow() +} + +// UpdateTimestamp updates the fields of the next free timestamp slot. +func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) { + slot := (*ts)[ts.Pointer()-1:] + + switch ts.Flags() { + case IPv4OptionTimestampOnlyFlag: + binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize + case IPv4OptionTimestampWithIPFlag: + if n := copy(slot, addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + case IPv4OptionTimestampWithPredefinedIPFlag: + if tcpip.Address(slot[:IPv4AddressSize]) == addr { + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + } + } +} + +// RecordRoute option specific related constants. +// +// from RFC 791 page 20: +// Record Route +// +// +--------+--------+--------+---------//--------+ +// |00000111| length | pointer| route data | +// +--------+--------+--------+---------//--------+ +// Type=7 +// +// The record route option provides a means to record the route of +// an internet datagram. +// +// The option begins with the option type code. The second octet +// is the option length which includes the option type code and the +// length octet, the pointer octet, and length-3 octets of route +// data. The third octet is the pointer into the route data +// indicating the octet which begins the next area to store a route +// address. The pointer is relative to this option, and the +// smallest legal value for the pointer is 4. +const ( + // IPv4OptionRecordRouteHdrLength is the length of the Record Route option + // header. + IPv4OptionRecordRouteHdrLength = 3 + + // IPv4OptRRPointerOffset is the offset to the pointer field in an RR + // option, which points to the next free slot in the list of addresses. + IPv4OptRRPointerOffset = 2 +) + +var _ IPv4Option = (*IPv4OptionRecordRoute)(nil) + +// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791. +type IPv4OptionRecordRoute []byte + +// Pointer returns the pointer field in the IP RecordRoute option. +func (rr *IPv4OptionRecordRoute) Pointer() uint8 { + return (*rr)[IPv4OptRRPointerOffset] +} + +// StoreAddress stores the given IPv4 address into the next free slot. +func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) { + start := rr.Pointer() - 1 // A one based number. + // start and room checked by caller. + if n := copy((*rr)[start:], addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize +} + +// Type implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType } + +// Size implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } + +// Contents implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } + +// Router Alert option specific related constants. +// +// from RFC 2113 section 2.1: +// +// +--------+--------+--------+--------+ +// |10010100|00000100| 2 octet value | +// +--------+--------+--------+--------+ +// +// Type: +// Copied flag: 1 (all fragments must carry the option) +// Option class: 0 (control) +// Option number: 20 (decimal) +// +// Length: 4 +// +// Value: A two octet code with the following values: +// 0 - Router shall examine packet +// 1-65535 - Reserved +const ( + // IPv4OptionRouterAlertLength is the length of a Router Alert option. + IPv4OptionRouterAlertLength = 4 + + // IPv4OptionRouterAlertValue is the only permissible value of the 16 bit + // payload of the router alert option. + IPv4OptionRouterAlertValue = 0 + + // iPv4OptionRouterAlertValueOffset is the offset for the value of a + // RouterAlert option. + iPv4OptionRouterAlertValueOffset = 2 +) + +// IPv4SerializableOption is an interface to represent serializable IPv4 option +// types. +type IPv4SerializableOption interface { + // optionType returns the type identifier of the option. + optionType() IPv4OptionType +} + +// IPv4SerializableOptionPayload is an interface providing serialization of the +// payload of an IPv4 option. +type IPv4SerializableOptionPayload interface { + // length returns the size of the payload. + length() uint8 + + // serializeInto serializes the payload into the provided byte buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // Length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MUST panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto(buffer []byte) uint8 +} + +// IPv4OptionsSerializer is a serializer for IPv4 options. +type IPv4OptionsSerializer []IPv4SerializableOption + +// Length returns the total number of bytes required to serialize the options. +func (s IPv4OptionsSerializer) Length() uint8 { + var total uint8 + for _, opt := range s { + total++ + if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok { + // Add 1 to reported length to account for the length byte. + total += 1 + withPayload.length() + } + } + return padIPv4OptionsLength(total) +} + +// Serialize serializes the provided list of IPV4 options into b. +// +// Note, b must be of sufficient size to hold all the options in s. See +// IPv4OptionsSerializer.Length for details on the getting the total size +// of a serialized IPv4OptionsSerializer. +// +// Serialize panics if b is not of sufficient size to hold all the options in s. +func (s IPv4OptionsSerializer) Serialize(b []byte) uint8 { + var total uint8 + for _, opt := range s { + ty := opt.optionType() + if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok { + // Serialize first to reduce bounds checks. + l := 2 + withPayload.serializeInto(b[2:]) + b[0] = byte(ty) + b[1] = l + b = b[l:] + total += l + continue + } + // Options without payload consist only of the type field. + // + // NB: Repeating code from the branch above is intentional to minimize + // bounds checks. + b[0] = byte(ty) + b = b[1:] + total++ + } + + // According to RFC 791: + // + // The internet header padding is used to ensure that the internet + // header ends on a 32 bit boundary. The padding is zero. + padded := padIPv4OptionsLength(total) + b = b[:padded-total] + for i := range b { + b[i] = 0 + } + return padded +} + +var _ IPv4SerializableOptionPayload = (*IPv4SerializableRouterAlertOption)(nil) +var _ IPv4SerializableOption = (*IPv4SerializableRouterAlertOption)(nil) + +// IPv4SerializableRouterAlertOption provides serialization of the Router Alert +// IPv4 option according to RFC 2113. +type IPv4SerializableRouterAlertOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableRouterAlertOption) optionType() IPv4OptionType { + return IPv4OptionRouterAlertType +} + +// Length implements IPv4SerializableOption. +func (*IPv4SerializableRouterAlertOption) length() uint8 { + return IPv4OptionRouterAlertLength - iPv4OptionRouterAlertValueOffset +} + +// SerializeInto implements IPv4SerializableOption. +func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 { + binary.BigEndian.PutUint16(buffer, IPv4OptionRouterAlertValue) + return o.length() +} + +var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil) + +// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option. +type IPv4SerializableNOPOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableNOPOption) optionType() IPv4OptionType { + return IPv4OptionNOPType +} + +var _ IPv4SerializableOption = (*IPv4SerializableListEndOption)(nil) + +// IPv4SerializableListEndOption provides serialization for the IPv4 List End +// option. +type IPv4SerializableListEndOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableListEndOption) optionType() IPv4OptionType { + return IPv4OptionListEndType +} diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go new file mode 100644 index 000000000..6475cd694 --- /dev/null +++ b/pkg/tcpip/header/ipv4_test.go @@ -0,0 +1,179 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestIPv4OptionsSerializer(t *testing.T) { + optCases := []struct { + name string + option []header.IPv4SerializableOption + expect []byte + }{ + { + name: "NOP", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableNOPOption{}, + }, + expect: []byte{1, 0, 0, 0}, + }, + { + name: "ListEnd", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableListEndOption{}, + }, + expect: []byte{0, 0, 0, 0}, + }, + { + name: "RouterAlert", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableRouterAlertOption{}, + }, + expect: []byte{148, 4, 0, 0}, + }, { + name: "NOP and RouterAlert", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableNOPOption{}, + &header.IPv4SerializableRouterAlertOption{}, + }, + expect: []byte{1, 148, 4, 0, 0, 0, 0, 0}, + }, + } + + for _, opt := range optCases { + t.Run(opt.name, func(t *testing.T) { + s := header.IPv4OptionsSerializer(opt.option) + l := s.Length() + if got := len(opt.expect); got != int(l) { + t.Fatalf("s.Length() = %d, want = %d", got, l) + } + b := make([]byte, l) + for i := range b { + // Fill the buffer with full bytes to ensure padding is being set + // correctly. + b[i] = 0xFF + } + if serializedLength := s.Serialize(b); serializedLength != l { + t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l) + } + if diff := cmp.Diff(opt.expect, b); diff != "" { + t.Errorf("mismatched serialized option (-want +got):\n%s", diff) + } + }) + } +} + +// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested +// fields when options are supplied. +func TestIPv4EncodeOptions(t *testing.T) { + tests := []struct { + name string + numberOfNops int + encodedOptions header.IPv4Options // reply should look like this + wantIHL int + }{ + { + name: "valid no options", + wantIHL: header.IPv4MinimumSize, + }, + { + name: "one byte options", + numberOfNops: 1, + encodedOptions: header.IPv4Options{1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "two byte options", + numberOfNops: 2, + encodedOptions: header.IPv4Options{1, 1, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "three byte options", + numberOfNops: 3, + encodedOptions: header.IPv4Options{1, 1, 1, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "four byte options", + numberOfNops: 4, + encodedOptions: header.IPv4Options{1, 1, 1, 1}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "five byte options", + numberOfNops: 5, + encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 8, + }, + { + name: "thirty nine byte options", + numberOfNops: 39, + encodedOptions: header.IPv4Options{ + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 0, + }, + wantIHL: header.IPv4MinimumSize + 40, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops)) + for i := range serializeOpts { + serializeOpts[i] = &header.IPv4SerializableNOPOption{} + } + paddedOptionLength := serializeOpts.Length() + ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength) + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength) + hdr := buffer.NewPrependable(int(totalLen)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + // To check the padding works, poison the last byte of the options space. + if paddedOptionLength != serializeOpts.Length() { + ip.SetHeaderLength(uint8(ipHeaderLength)) + ip.Options()[paddedOptionLength-1] = 0xff + ip.SetHeaderLength(0) + } + ip.Encode(&header.IPv4Fields{ + Options: serializeOpts, + }) + options := ip.Options() + wantOptions := test.encodedOptions + if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { + t.Errorf("got IHL of %d, want %d", got, want) + } + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(wantOptions) == 0 && len(options) == 0 { + return + } + + if diff := cmp.Diff(wantOptions, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index c5d8a3456..d522e5f10 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -48,13 +48,15 @@ type IPv6Fields struct { // FlowLabel is the "flow label" field of an IPv6 packet. FlowLabel uint32 - // PayloadLength is the "payload length" field of an IPv6 packet. + // PayloadLength is the "payload length" field of an IPv6 packet, including + // the length of all extension headers. PayloadLength uint16 - // NextHeader is the "next header" field of an IPv6 packet. - NextHeader uint8 + // TransportProtocol is the transport layer protocol number. Serialized in the + // last "next header" field of the IPv6 header + extension headers. + TransportProtocol tcpip.TransportProtocolNumber - // HopLimit is the "hop limit" field of an IPv6 packet. + // HopLimit is the "Hop Limit" field of an IPv6 packet. HopLimit uint8 // SrcAddr is the "source ip address" of an IPv6 packet. @@ -62,6 +64,9 @@ type IPv6Fields struct { // DstAddr is the "destination ip address" of an IPv6 packet. DstAddr tcpip.Address + + // ExtensionHeaders are the extension headers following the IPv6 header. + ExtensionHeaders IPv6ExtHdrSerializer } // IPv6 represents an ipv6 header stored in a byte array. @@ -101,8 +106,10 @@ const ( // The address is ff02::2. IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, - // section 5. + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, + // section 5: + // IPv6 requires that every link in the Internet have an MTU of 1280 octets + // or greater. This is known as the IPv6 minimum link MTU. IPv6MinimumMTU = 1280 // IPv6Loopback is the IPv6 Loopback address. @@ -169,7 +176,7 @@ func (b IPv6) PayloadLength() uint16 { return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) } -// HopLimit returns the value of the "hop limit" field of the ipv6 header. +// HopLimit returns the value of the "Hop Limit" field of the ipv6 header. func (b IPv6) HopLimit() uint8 { return b[hopLimit] } @@ -234,6 +241,11 @@ func (b IPv6) SetDestinationAddress(addr tcpip.Address) { copy(b[v6DstAddr:][:IPv6AddressSize], addr) } +// SetHopLimit sets the value of the "Hop Limit" field. +func (b IPv6) SetHopLimit(v uint8) { + b[hopLimit] = v +} + // SetNextHeader sets the value of the "next header" field of the ipv6 header. func (b IPv6) SetNextHeader(v uint8) { b[IPv6NextHeaderOffset] = v @@ -246,12 +258,14 @@ func (IPv6) SetChecksum(uint16) { // Encode encodes all the fields of the ipv6 header. func (b IPv6) Encode(i *IPv6Fields) { + extHdr := b[IPv6MinimumSize:] b.SetTOS(i.TrafficClass, i.FlowLabel) b.SetPayloadLength(i.PayloadLength) - b[IPv6NextHeaderOffset] = i.NextHeader b[hopLimit] = i.HopLimit b.SetSourceAddress(i.SrcAddr) b.SetDestinationAddress(i.DstAddr) + nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr) + b[IPv6NextHeaderOffset] = nextHeader } // IsValid performs basic validation on the packet. @@ -373,6 +387,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback +// address. +func IsV6LoopbackAddress(addr tcpip.Address) bool { + return addr == IPv6Loopback +} + // IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 // link-local multicast address. func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 583c2c5d3..1fbb2cc98 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -20,7 +20,9 @@ import ( "encoding/binary" "fmt" "io" + "math" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -47,6 +49,11 @@ const ( // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end // of an IPv6 payload, as per RFC 8200 section 4.7. IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59 + + // IPv6UnknownExtHdrIdentifier is reserved by IANA. + // https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header + // "254 Use for experimentation and testing [RFC3692][RFC4727]" + IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254 ) const ( @@ -70,8 +77,8 @@ const ( // Fragment Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetOffset = 0 - // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to - // discard from the Fragment Offset. + // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment + // Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetShift = 3 // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an @@ -109,6 +116,37 @@ const ( IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8 ) +// padIPv6OptionsLength returns the total length for IPv6 options of length l +// considering the 8-octet alignment as stated in RFC 8200 Section 4.2. +func padIPv6OptionsLength(length int) int { + return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1) +} + +// padIPv6Option fills b with the appropriate padding options depending on its +// length. +func padIPv6Option(b []byte) { + switch len(b) { + case 0: // No padding needed. + case 1: // Pad with Pad1. + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier) + default: // Pad with PadN. + s := b[ipv6ExtHdrOptionPayloadOffset:] + for i := range s { + s[i] = 0 + } + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier) + b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s)) + } +} + +// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to +// serialize an option at headerOffset with alignment requirements +// [align]n + alignOffset. +func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int { + padLen := headerOffset - alignOffset + return ((padLen + align - 1) & ^(align - 1)) - padLen +} + // IPv6PayloadHeader is implemented by the various headers that can be found // in an IPv6 payload. // @@ -201,29 +239,51 @@ type IPv6ExtHdrOption interface { isIPv6ExtHdrOption() } -// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier. -type IPv6ExtHdrOptionIndentifier uint8 +// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier. +type IPv6ExtHdrOptionIdentifier uint8 const ( // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that // provides 1 byte padding, as outlined in RFC 8200 section 4.2. - ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0 + ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0 // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that // provides variable length byte padding, as outlined in RFC 8200 section 4.2. - ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1 + ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1 + + // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router + // Alert Hop by Hop option as defined in RFC 2711 section 2.1. + ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5 + + // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header + // option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionTypeOffset = 0 + + // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionLengthOffset = 1 + + // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionPayloadOffset = 2 ) +// ipv6UnknownActionFromIdentifier maps an extension header option's +// identifier's high bits to the action to take when the identifier is unknown. +func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction { + return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) +} + // IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension // header option that is unknown by the parsing utilities. type IPv6UnknownExtHdrOption struct { - Identifier IPv6ExtHdrOptionIndentifier + Identifier IPv6ExtHdrOptionIdentifier Data []byte } // UnknownAction implements IPv6OptionUnknownAction.UnknownAction. func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction { - return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) + return ipv6UnknownActionFromIdentifier(o.Identifier) } // isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption. @@ -246,7 +306,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // options buffer has been exhausted and we are done iterating. return nil, true, nil } - id := IPv6ExtHdrOptionIndentifier(temp) + id := IPv6ExtHdrOptionIdentifier(temp) // If the option identifier indicates the option is a Pad1 option, then we // know the option does not have Length and Data fields. End processing of @@ -289,6 +349,14 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } continue + case ipv6RouterAlertHopByHopOptionIdentifier: + var routerAlertValue [ipv6RouterAlertPayloadLength]byte + if n, err := i.reader.Read(routerAlertValue[:]); err != nil { + panic(fmt.Sprintf("error when reading RouterAlert option's data bytes: %s", err)) + } else if n != ipv6RouterAlertPayloadLength { + return nil, true, fmt.Errorf("read %d bytes for RouterAlert option, expected %d", n, ipv6RouterAlertPayloadLength) + } + return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil default: bytes := make([]byte, length) if n, err := io.ReadFull(&i.reader, bytes); err != nil { @@ -452,9 +520,11 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { // Since we consume the iterator, we return the payload as is. buf = i.payload - // Mark i as done. + // Mark i as done, but keep track of where we were for error reporting. *i = IPv6PayloadIterator{ nextHdrIdentifier: IPv6NoNextHeaderIdentifier, + headerOffset: i.headerOffset, + nextOffset: i.nextOffset, } } else { buf = i.payload.Clone(nil) @@ -602,3 +672,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil } + +// IPv6SerializableExtHdr provides serialization for IPv6 extension +// headers. +type IPv6SerializableExtHdr interface { + // identifier returns the assigned IPv6 header identifier for this extension + // header. + identifier() IPv6ExtensionHeaderIdentifier + + // length returns the total serialized length in bytes of this extension + // header, including the common next header and length fields. + length() int + + // serializeInto serializes the receiver into the provided byte + // buffer and with the provided nextHeader value. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto returns the number of bytes that was used to serialize the + // receiver. Implementers must only use the number of bytes required to + // serialize the receiver. Callers MAY provide a larger buffer than required + // to serialize into. + serializeInto(nextHeader uint8, b []byte) int +} + +var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil) + +// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop +// options extension header. +type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption + +const ( + // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field + // in a hop by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrNextHeaderOffset = 0 + + // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop + // by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrLengthOffset = 1 + + // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by + // hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrOptionsOffset = 2 + + // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet + // words in a hop by hop extension header's length field, as stated in RFC + // 8200 section 4.3: + // Length of the Hop-by-Hop Options header in 8-octet units, + // not including the first 8 octets. + ipv6HopByHopExtHdrUnaccountedLenWords = 1 +) + +// identifier implements IPv6SerializableExtHdr. +func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6HopByHopOptionsExtHdrIdentifier +} + +// length implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) length() int { + var total int + for _, opt := range h { + align, alignOffset := opt.alignment() + total += ipv6OptionsAlignmentPadding(total, align, alignOffset) + total += ipv6ExtHdrOptionPayloadOffset + int(opt.length()) + } + // Account for next header and total length fields and add padding. + return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total) +} + +// serializeInto implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int { + optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:] + totalLength := ipv6HopByHopExtHdrOptionsOffset + for _, opt := range h { + // Calculate alignment requirements and pad buffer if necessary. + align, alignOffset := opt.alignment() + padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset) + if padLen != 0 { + padIPv6Option(optBuffer[:padLen]) + totalLength += padLen + optBuffer = optBuffer[padLen:] + } + + l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:]) + optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier()) + optBuffer[ipv6ExtHdrOptionLengthOffset] = l + l += ipv6ExtHdrOptionPayloadOffset + totalLength += int(l) + optBuffer = optBuffer[l:] + } + padded := padIPv6OptionsLength(totalLength) + if padded != totalLength { + padIPv6Option(optBuffer[:padded-totalLength]) + totalLength = padded + } + wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords + if wordsLen > math.MaxUint8 { + panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen)) + } + b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader + b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen) + return totalLength +} + +// IPv6SerializableHopByHopOption provides serialization for hop by hop options. +type IPv6SerializableHopByHopOption interface { + // identifier returns the option identifier of this Hop by Hop option. + identifier() IPv6ExtHdrOptionIdentifier + + // length returns the *payload* size of the option (not considering the type + // and length fields). + length() uint8 + + // alignment returns the alignment requirements from this option. + // + // Alignment requirements take the form [align]n + offset as specified in + // RFC 8200 section 4.2. The alignment requirement is on the offset between + // the option type byte and the start of the hop by hop header. + // + // align must be a power of 2. + alignment() (align int, offset int) + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) uint8 +} + +var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil) + +// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in +// RFC 2711 section 2.1. +type IPv6RouterAlertOption struct { + Value IPv6RouterAlertValue +} + +// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option. +type IPv6RouterAlertValue uint16 + +const ( + // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener + // Discovery message as defined in RFC 2711 section 2.1. + IPv6RouterAlertMLD IPv6RouterAlertValue = 0 + // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as + // defined in RFC 2711 section 2.1. + IPv6RouterAlertRSVP IPv6RouterAlertValue = 1 + // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active + // Networks message as defined in RFC 2711 section 2.1. + IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2 + + // ipv6RouterAlertPayloadLength is the length of the Router Alert payload + // as defined in RFC 2711. + ipv6RouterAlertPayloadLength = 2 + + // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the + // Router Alert option defined as 2n+0 in RFC 2711. + ipv6RouterAlertAlignmentRequirement = 2 + + // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset + // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section + // 2.1. + ipv6RouterAlertAlignmentOffsetRequirement = 0 +) + +// UnknownAction implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction { + return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier) +} + +// isIPv6ExtHdrOption implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {} + +// identifier implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier { + return ipv6RouterAlertHopByHopOptionIdentifier +} + +// length implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) length() uint8 { + return ipv6RouterAlertPayloadLength +} + +// alignment implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) alignment() (int, int) { + // From RFC 2711 section 2.1: + // Alignment requirement: 2n+0. + return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 { + binary.BigEndian.PutUint16(b, uint16(o.Value)) + return ipv6RouterAlertPayloadLength +} + +// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers. +type IPv6ExtHdrSerializer []IPv6SerializableExtHdr + +// Serialize serializes the provided list of IPv6 extension headers into b. +// +// Note, b must be of sufficient size to hold all the headers in s. See +// IPv6ExtHdrSerializer.Length for details on the getting the total size of a +// serialized IPv6ExtHdrSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +// +// Serialize takes the transportProtocol value to be used as the last extension +// header's Next Header value and returns the header identifier of the first +// serialized extension header and the total serialized length. +func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) { + nextHeader := uint8(transportProtocol) + if len(s) == 0 { + return nextHeader, 0 + } + var totalLength int + for i, h := range s[:len(s)-1] { + length := h.serializeInto(uint8(s[i+1].identifier()), b) + b = b[length:] + totalLength += length + } + totalLength += s[len(s)-1].serializeInto(nextHeader, b) + return uint8(s[0].identifier()), totalLength +} + +// Length returns the total number of bytes required to serialize the extension +// headers. +func (s IPv6ExtHdrSerializer) Length() int { + var totalLength int + for _, h := range s { + totalLength += h.length() + } + return totalLength +} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go index ab20c5f37..5d2296353 100644 --- a/pkg/tcpip/header/ipv6_extension_headers_test.go +++ b/pkg/tcpip/header/ipv6_extension_headers_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool func TestIPv6UnknownExtHdrOption(t *testing.T) { tests := []struct { name string - identifier IPv6ExtHdrOptionIndentifier + identifier IPv6ExtHdrOptionIdentifier expectedUnknownAction IPv6OptionUnknownAction }{ { @@ -990,3 +991,331 @@ func TestIPv6ExtHdrIter(t *testing.T) { }) } } + +var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil) + +// dummyHbHOptionSerializer provides a generic implementation of +// IPv6SerializableHopByHopOption for use in tests. +type dummyHbHOptionSerializer struct { + id IPv6ExtHdrOptionIdentifier + payload []byte + align int + alignOffset int +} + +// identifier implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier { + return s.id +} + +// length implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) length() uint8 { + return uint8(len(s.payload)) +} + +// alignment implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) alignment() (int, int) { + align := 1 + if s.align != 0 { + align = s.align + } + return align, s.alignOffset +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 { + return uint8(copy(b, s.payload)) +} + +func TestIPv6HopByHopSerializer(t *testing.T) { + validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + dummy, ok := serializable.(*dummyHbHOptionSerializer) + if !ok { + t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable) + } + unknown, ok := deserialized.(*IPv6UnknownExtHdrOption) + if !ok { + t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{}) + } + if dummy.id != unknown.Identifier { + t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id) + } + if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" { + t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff) + } + } + tests := []struct { + name string + nextHeader uint8 + options []IPv6SerializableHopByHopOption + expect []byte + validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption) + }{ + { + name: "single option", + nextHeader: 13, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 15, + payload: []byte{9, 8, 7, 6}, + }, + }, + expect: []byte{13, 0, 15, 4, 9, 8, 7, 6}, + validate: validateDummies, + }, + { + name: "short option padN zero", + nextHeader: 88, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5}, + }, + }, + expect: []byte{88, 0, 22, 2, 4, 5, 1, 0}, + validate: validateDummies, + }, + { + name: "short option pad1", + nextHeader: 11, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 33, + payload: []byte{1, 2, 3}, + }, + }, + expect: []byte{11, 0, 33, 3, 1, 2, 3, 0}, + validate: validateDummies, + }, + { + name: "long option padN", + nextHeader: 55, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 77, + payload: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + }, + expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options align 2n", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 2, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0}, + validate: validateDummies, + }, + { + name: "two options align 8n+1", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 8, + alignOffset: 1, + }, + }, + expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0}, + validate: validateDummies, + }, + { + name: "no options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{}, + expect: []byte{33, 0, 1, 4, 0, 0, 0, 0}, + }, + { + name: "Router Alert", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}}, + expect: []byte{33, 0, 5, 2, 0, 0, 1, 0}, + validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + routerAlert, ok := deserialized.(*IPv6RouterAlertOption) + if !ok { + t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized) + } + if routerAlert.Value != IPv6RouterAlertMLD { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD) + } + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6SerializableHopByHopExtHdr(test.options) + length := s.length() + if length != len(test.expect) { + t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect)) + } + b := make([]byte, length) + for i := range b { + // Fill the buffer with ones to ensure all padding is correctly set. + b[i] = 0xFF + } + if got := s.serializeInto(test.nextHeader, b); got != length { + t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length) + } + if diff := cmp.Diff(test.expect, b); diff != "" { + t.Fatalf("serialization mismatch (-want +got):\n%s", diff) + } + + // Deserialize the options and verify them. + optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit + iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter() + for _, testOpt := range test.options { + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + test.validate(t, testOpt, opt) + } + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if !done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + }) + } +} + +var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil) + +// dummyIPv6ExtHdrSerializer provides a generic implementation of +// IPv6SerializableExtHdr for use in tests. +// +// The dummy header always carries the nextHeader value in the first byte. +type dummyIPv6ExtHdrSerializer struct { + id IPv6ExtensionHeaderIdentifier + headerContents []byte +} + +// identifier implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier { + return s.id +} + +// length implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) length() int { + return len(s.headerContents) + 1 +} + +// serializeInto implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int { + b[0] = nextHeader + return copy(b[1:], s.headerContents) + 1 +} + +func TestIPv6ExtHdrSerializer(t *testing.T) { + tests := []struct { + name string + headers []IPv6SerializableExtHdr + nextHeader tcpip.TransportProtocolNumber + expectSerialized []byte + expectNextHeader uint8 + }{ + { + name: "one header", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 15, + headerContents: []byte{1, 2, 3, 4}, + }, + }, + nextHeader: TCPProtocolNumber, + expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4}, + expectNextHeader: 15, + }, + { + name: "two headers", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 22, + headerContents: []byte{1, 2, 3}, + }, + &dummyIPv6ExtHdrSerializer{ + id: 23, + headerContents: []byte{4, 5, 6}, + }, + }, + nextHeader: ICMPv6ProtocolNumber, + expectSerialized: []byte{ + 23, 1, 2, 3, + byte(ICMPv6ProtocolNumber), 4, 5, 6, + }, + expectNextHeader: 22, + }, + { + name: "no headers", + headers: []IPv6SerializableExtHdr{}, + nextHeader: UDPProtocolNumber, + expectSerialized: []byte{}, + expectNextHeader: byte(UDPProtocolNumber), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6ExtHdrSerializer(test.headers) + l := s.Length() + if got, want := l, len(test.expectSerialized); got != want { + t.Fatalf("got serialized length = %d, want = %d", got, want) + } + b := make([]byte, l) + for i := range b { + // Fill the buffer with garbage to make sure we're writing to all bytes. + b[i] = 0xFF + } + nextHeader, serializedLen := s.Serialize(test.nextHeader, b) + if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader { + t.Errorf( + "got s.Serialize(..) = (%d, %d), want = (%d, %d)", + nextHeader, + serializedLen, + test.expectNextHeader, + len(test.expectSerialized), + ) + } + if diff := cmp.Diff(test.expectSerialized, b); diff != "" { + t.Errorf("serialization mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go index 018555a26..9d09f32eb 100644 --- a/pkg/tcpip/header/ipv6_fragment.go +++ b/pkg/tcpip/header/ipv6_fragment.go @@ -27,12 +27,11 @@ const ( idV6 = 4 ) -// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the -// fields of a packet that needs to be encoded. -type IPv6FragmentFields struct { - // NextHeader is the "next header" field of an IPv6 fragment. - NextHeader uint8 +var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil) +// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment +// extension header as defined in RFC 8200 section 4.5. +type IPv6SerializableFragmentExtHdr struct { // FragmentOffset is the "fragment offset" field of an IPv6 fragment. FragmentOffset uint16 @@ -43,6 +42,29 @@ type IPv6FragmentFields struct { Identification uint32 } +// identifier implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6FragmentHeader +} + +// length implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) length() int { + return IPv6FragmentHeaderSize +} + +// serializeInto implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int { + // Prevent too many bounds checks. + _ = b[IPv6FragmentHeaderSize:] + binary.BigEndian.PutUint32(b[idV6:], h.Identification) + binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift) + b[nextHdrFrag] = nextHeader + if h.M { + b[more] |= ipv6FragmentExtHdrMFlagMask + } + return IPv6FragmentHeaderSize +} + // IPv6Fragment represents an ipv6 fragment header stored in a byte array. // Most of the methods of IPv6Fragment access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. @@ -58,16 +80,6 @@ const ( IPv6FragmentHeaderSize = 8 ) -// Encode encodes all the fields of the ipv6 fragment. -func (b IPv6Fragment) Encode(i *IPv6FragmentFields) { - b[nextHdrFrag] = i.NextHeader - binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3) - if i.M { - b[more] |= 1 - } - binary.BigEndian.PutUint32(b[idV6:], i.Identification) -} - // IsValid performs basic validation on the fragment header. func (b IPv6Fragment) IsValid() bool { return len(b) >= IPv6FragmentHeaderSize diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go index 17a49d4fa..b5540bf66 100644 --- a/pkg/tcpip/header/ipversion_test.go +++ b/pkg/tcpip/header/ipversion_test.go @@ -22,7 +22,7 @@ import ( func TestIPv4(t *testing.T) { b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize}) + b.Encode(&header.IPv4Fields{}) const want = header.IPv4Version if v := header.IPVersion(b); v != want { diff --git a/pkg/tcpip/header/mld.go b/pkg/tcpip/header/mld.go new file mode 100644 index 000000000..ffe03c76a --- /dev/null +++ b/pkg/tcpip/header/mld.go @@ -0,0 +1,103 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // MLDMinimumSize is the minimum size for an MLD message. + MLDMinimumSize = 20 + + // MLDHopLimit is the Hop Limit for all IPv6 packets with an MLD message, as + // per RFC 2710 section 3. + MLDHopLimit = 1 + + // mldMaximumResponseDelayOffset is the offset to the Maximum Response Delay + // field within MLD. + mldMaximumResponseDelayOffset = 0 + + // mldMulticastAddressOffset is the offset to the Multicast Address field + // within MLD. + mldMulticastAddressOffset = 4 +) + +// MLD is a Multicast Listener Discovery message in an ICMPv6 packet. +// +// MLD will only contain the body of an ICMPv6 packet. +// +// As per RFC 2710 section 3, MLD messages have the following format (MLD only +// holds the bytes after the first four bytes in the diagram below): +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Code | Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Maximum Response Delay | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// + + +// | | +// + Multicast Address + +// | | +// + + +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +type MLD []byte + +// MaximumResponseDelay returns the Maximum Response Delay. +func (m MLD) MaximumResponseDelay() time.Duration { + // As per RFC 2710 section 3.4: + // + // The Maximum Response Delay field is meaningful only in Query + // messages, and specifies the maximum allowed delay before sending a + // responding Report, in units of milliseconds. In all other messages, + // it is set to zero by the sender and ignored by receivers. + return time.Duration(binary.BigEndian.Uint16(m[mldMaximumResponseDelayOffset:])) * time.Millisecond +} + +// SetMaximumResponseDelay sets the Maximum Response Delay field. +// +// maxRespDelayMS is the value in milliseconds. +func (m MLD) SetMaximumResponseDelay(maxRespDelayMS uint16) { + binary.BigEndian.PutUint16(m[mldMaximumResponseDelayOffset:], maxRespDelayMS) +} + +// MulticastAddress returns the Multicast Address. +func (m MLD) MulticastAddress() tcpip.Address { + // As per RFC 2710 section 3.5: + // + // In a Query message, the Multicast Address field is set to zero when + // sending a General Query, and set to a specific IPv6 multicast address + // when sending a Multicast-Address-Specific Query. + // + // In a Report or Done message, the Multicast Address field holds a + // specific IPv6 multicast address to which the message sender is + // listening or is ceasing to listen, respectively. + return tcpip.Address(m[mldMulticastAddressOffset:][:IPv6AddressSize]) +} + +// SetMulticastAddress sets the Multicast Address field. +func (m MLD) SetMulticastAddress(multicastAddress tcpip.Address) { + if n := copy(m[mldMulticastAddressOffset:], multicastAddress); n != IPv6AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, IPv6AddressSize)) + } +} diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go new file mode 100644 index 000000000..0cecf10d4 --- /dev/null +++ b/pkg/tcpip/header/mld_test.go @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestMLD(t *testing.T) { + b := []byte{ + // Maximum Response Delay + 0, 0, + + // Reserved + 0, 0, + + // MulticastAddress + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, + } + + const maxRespDelay = 513 + binary.BigEndian.PutUint16(b, maxRespDelay) + + mld := MLD(b) + + if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want { + t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) + } + + const newMaxRespDelay = 1234 + mld.SetMaximumResponseDelay(newMaxRespDelay) + if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want { + t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) + } + + if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want { + t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want) + } + + multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}) + mld.SetMulticastAddress(multicastAddress) + if got := mld.MulticastAddress(); got != multicastAddress { + t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress) + } +} diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 5d3975c56..554242f0c 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -298,7 +298,7 @@ func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { return it, nil } -// Serialize serializes the provided list of NDP options into o. +// Serialize serializes the provided list of NDP options into b. // // Note, b must be of sufficient size to hold all the options in s. See // NDPOptionsSerializer.Length for details on the getting the total size diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 5ca75c834..2042f214a 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -109,6 +109,9 @@ traverseExtensions: fragOffset = extHdr.FragmentOffset() fragMore = extHdr.More() } + rawPayload := it.AsRawHeader(true /* consume */) + extensionsSize = dataClone.Size() - rawPayload.Buf.Size() + break traverseExtensions case header.IPv6RawPayloadHeader: // We've found the payload after any extensions. diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 98bdd29db..a6d4fcd59 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -36,10 +36,10 @@ const ( // UDPFields contains the fields of a UDP packet. It is used to describe the // fields of a packet that needs to be encoded. type UDPFields struct { - // SrcPort is the "source port" field of a UDP packet. + // SrcPort is the "Source Port" field of a UDP packet. SrcPort uint16 - // DstPort is the "destination port" field of a UDP packet. + // DstPort is the "Destination Port" field of a UDP packet. DstPort uint16 // Length is the "length" field of a UDP packet. @@ -64,52 +64,57 @@ const ( UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) -// SourcePort returns the "source port" field of the udp header. +// SourcePort returns the "Source Port" field of the UDP header. func (b UDP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } -// DestinationPort returns the "destination port" field of the udp header. +// DestinationPort returns the "Destination Port" field of the UDP header. func (b UDP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } -// Length returns the "length" field of the udp header. +// Length returns the "Length" field of the UDP header. func (b UDP) Length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } // Payload returns the data contained in the UDP datagram. func (b UDP) Payload() []byte { - return b[UDPMinimumSize:] + return b[:b.Length()][UDPMinimumSize:] } -// Checksum returns the "checksum" field of the udp header. +// Checksum returns the "checksum" field of the UDP header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) } -// SetSourcePort sets the "source port" field of the udp header. +// SetSourcePort sets the "source port" field of the UDP header. func (b UDP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[udpSrcPort:], port) } -// SetDestinationPort sets the "destination port" field of the udp header. +// SetDestinationPort sets the "destination port" field of the UDP header. func (b UDP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[udpDstPort:], port) } -// SetChecksum sets the "checksum" field of the udp header. +// SetChecksum sets the "checksum" field of the UDP header. func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } -// SetLength sets the "length" field of the udp header. +// SetLength sets the "length" field of the UDP header. func (b UDP) SetLength(length uint16) { binary.BigEndian.PutUint16(b[udpLength:], length) } -// CalculateChecksum calculates the checksum of the udp packet, given the +// PayloadLength returns the length of the payload following the UDP header. +func (b UDP) PayloadLength() uint16 { + return b.Length() - UDPMinimumSize +} + +// CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 39ca774ef..973f06cbc 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -9,7 +9,6 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c95aef63c..0efbfb22b 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -32,7 +31,7 @@ type PacketInfo struct { Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO - Route stack.Route + Route *stack.Route } // Notification is the interface for receiving notification from the packet @@ -271,21 +270,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := PacketInfo{ - Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }), - Proto: 0, - GSO: nil, - } - - e.q.Write(p) - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index 3eef7cd56..beefcd008 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -62,7 +62,7 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { // WritePacket implements stack.LinkEndpoint. func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress(), proto, pkt) return e.Endpoint.WritePacket(r, gso, proto, pkt) } @@ -71,7 +71,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) + e.AddHeader(linkAddr, r.RemoteLinkAddress(), proto, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, proto) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 975309fc8..cb94cbea6 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher } switch sa.(type) { case *unix.SockaddrLinklayer: - // enable PACKET_FANOUT mode is the underlying socket is - // of type AF_PACKET. - const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG + // Enable PACKET_FANOUT mode if the underlying socket is of type + // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will + // prevent gvisor from receiving fragmented packets and the host does the + // reassembly on our behalf before delivering the fragments. This makes it + // hard to test fragmentation reassembly code in Netstack. + const fanoutType = unix.PACKET_FANOUT_HASH fanoutArg := fanoutID | fanoutType<<16 if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) @@ -410,7 +413,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt) } var builder iovec.Builder @@ -453,7 +456,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { if e.hdrSize > 0 { - e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress(), pkt.NetworkProtocolNumber, pkt) } var vnetHdrBuf []byte @@ -558,11 +561,6 @@ func viewsEqual(vs1, vs2 []buffer.View) bool { return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) -} - // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 709f829c8..ce4da7230 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -183,9 +183,8 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) defer c.cleanup() - r := &stack.Route{ - RemoteLinkAddress: raddr, - } + var r stack.Route + r.ResolveWith(raddr) // Build payload. payload := buffer.NewView(plen) @@ -220,7 +219,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { + if err := c.ep.WritePacket(&r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -325,9 +324,9 @@ func TestPreserveSrcAddress(t *testing.T) { // Set LocalLinkAddress in route to the value of the bridged address. r := &stack.Route{ - RemoteLinkAddress: raddr, - LocalLinkAddress: baddr, + LocalLinkAddress: baddr, } + r.ResolveWith(raddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ // WritePacket panics given a prependable with anything less than diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 38aa694e4..edca57e4e 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -96,23 +96,6 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - // There should be an ethernet header at the beginning of vv. - hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. - return tcpip.ErrBadAddress - } - linkHeader := header.Ethernet(hdr) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) - - return nil -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareLoopback diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index e7493e5c5..cbda59775 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 56a611825..22e79ce3a 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -17,7 +17,6 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -106,13 +105,6 @@ func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protoco return tcpip.ErrNoRoute } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // WriteRawPacket doesn't get a route or network address, so there's - // nowhere to write this. - return tcpip.ErrNoRoute -} - // InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD index 2cdb23475..00b42b924 100644 --- a/pkg/tcpip/link/nested/BUILD +++ b/pkg/tcpip/link/nested/BUILD @@ -11,7 +11,6 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index d40de54df..0ee54c3d5 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -19,7 +19,6 @@ package nested import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -123,11 +122,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return e.child.WritePackets(r, gso, pkts, protocol) } -// WriteRawPacket implements stack.LinkEndpoint. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - return e.child.WriteRawPacket(vv) -} - // Wait implements stack.LinkEndpoint. func (e *Endpoint) Wait() { e.child.Wait() diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index 3922c2a04..9a1b0c0c2 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -36,14 +36,14 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { // WritePacket implements stack.LinkEndpoint.WritePacket. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress(), r.LocalLinkAddress, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress(), pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, proto) diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 523b0d24b..25c364391 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -55,7 +55,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this // endpoint is the local address on the other end. - e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress() /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), })) @@ -67,11 +67,6 @@ func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint. -func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - panic("not implemented") -} - // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 1d0079bd6..5bea598eb 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -13,7 +13,6 @@ go_library( "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index fc1e34fc7..27667f5f0 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -156,7 +155,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. newRoute := r.Clone() - pkt.EgressRoute = &newRoute + pkt.EgressRoute = newRoute pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] @@ -183,7 +182,7 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // the route here to ensure it doesn't get released while the // packet is still in our queue. newRoute := pkt.EgressRoute.Clone() - pkt.EgressRoute = &newRoute + pkt.EgressRoute = newRoute if !d.q.enqueue(pkt) { if enqueued > 0 { d.newPacketWaker.Assert() @@ -197,13 +196,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB return enqueued, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // TODO(gvisor.dev/issue/3267): Queue these packets as well once - // WriteRawPacket takes PacketBuffer instead of VectorisedView. - return e.lower.WriteRawPacket(vv) -} - // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index dc239a0d0..2777f1411 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) { const count = 1000000 var wg sync.WaitGroup + defer wg.Wait() wg.Add(1) go func() { defer wg.Done() @@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) { } }() - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } + for i := 0; i < count; i++ { + n := 1 + rr.Intn(80) + rb := rx.Pull() + for rb == nil { + rb = rx.Pull() + } - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } + if n != len(rb) { + t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) + } - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } + for j := range rb { + if v := byte(rr.Intn(256)); v != rb[j] { + t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) } - - rx.Flush() } - }() - wg.Wait() + rx.Flush() + } } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 7fb8a6c49..5660418fa 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -204,7 +204,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt) views := pkt.Views() // Transmit the packet. @@ -224,21 +224,6 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - views := vv.Views() - // Transmit the packet. - e.mu.Lock() - ok := e.tx.transmit(views...) - e.mu.Unlock() - - if !ok { - return tcpip.ErrWouldBlock - } - - return nil -} - // dispatchLoop reads packets from the rx queue in a loop and dispatches them // to the network stack. func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 22d5c97f1..7131392cc 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -260,9 +260,8 @@ func TestSimpleSend(t *testing.T) { defer c.cleanup() // Prepare route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) for iters := 1000; iters > 0; iters-- { func() { @@ -342,9 +341,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) { newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) // Set both remote and local link address in route. r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - LocalLinkAddress: newLocalLinkAddress, + LocalLinkAddress: newLocalLinkAddress, } + r.ResolveWith(remoteLinkAddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ // WritePacket panics given a prependable with anything less than @@ -395,9 +394,8 @@ func TestFillTxQueue(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -444,9 +442,8 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { c.txq.rx.Flush() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -509,9 +506,8 @@ func TestFillTxMemory(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -557,9 +553,8 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 560477926..8d9a91020 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -53,16 +53,35 @@ type endpoint struct { nested.Endpoint writer io.Writer maxPCAPLen uint32 + logPrefix string } var _ stack.GSOEndpoint = (*endpoint)(nil) var _ stack.LinkEndpoint = (*endpoint)(nil) var _ stack.NetworkDispatcher = (*endpoint)(nil) +type direction int + +const ( + directionSend = iota + directionRecv +) + // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - sniffer := &endpoint{} + return NewWithPrefix(lower, "") +} + +// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around +// another endpoint and logs packets prefixed with logPrefix as they traverse +// the endpoint. +// +// logPrefix is prepended to the log line without any separators. +// E.g. logPrefix = "NIC:en0/" will produce log lines like +// "NIC:en0/send udp [...]". +func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint { + sniffer := &endpoint{logPrefix: logPrefix} sniffer.Endpoint.Init(lower, sniffer) return sniffer } @@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, pkt) + e.dumpPacket(directionRecv, nil, protocol, pkt) e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } -func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + logPacket(e.logPrefix, dir, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Size() @@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, pkt) + e.dumpPacket(directionSend, gso, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } @@ -178,20 +197,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // forwards the request to the lower endpoint. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.dumpPacket("send", gso, protocol, pkt) + e.dumpPacket(directionSend, gso, protocol, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, protocol) } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return e.Endpoint.WriteRawPacket(vv) -} - -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -201,11 +212,26 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var fragmentOffset uint16 var moreFragments bool + var directionPrefix string + switch dir { + case directionSend: + directionPrefix = "send" + case directionRecv: + directionPrefix = "recv" + default: + panic(fmt.Sprintf("unrecognized direction: %d", dir)) + } + // Clone the packet buffer to not modify the original. // // We don't clone the original packet buffer so that the new packet buffer // does not have any of its headers set. - pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) + // + // We trim the link headers from the cloned buffer as the sniffer doesn't + // handle link headers. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + vv.TrimFront(len(pkt.LinkHeader().View())) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) switch protocol { case header.IPv4ProtocolNumber: if ok := parse.IPv4(pkt); !ok { @@ -243,15 +269,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( - "%s arp %s (%s) -> %s (%s) valid:%t", + "%s%s arp %s (%s) -> %s (%s) valid:%t", prefix, + directionPrefix, tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), arp.IsValid(), ) return default: - log.Infof("%s unknown network protocol", prefix) + log.Infof("%s%s unknown network protocol", prefix, directionPrefix) return } @@ -295,7 +322,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P icmpType = "info reply" } } - log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.ICMPv6ProtocolNumber: @@ -330,7 +357,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6RedirectMsg: icmpType = "redirect message" } - log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.UDPProtocolNumber: @@ -386,7 +413,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P } default: - log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto) + log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto) return } @@ -394,5 +421,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P details += fmt.Sprintf(" gso: %+v", gso) } - log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) + log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details) } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 0243424f6..86f14db76 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "tun_endpoint_refs.go", package = "tun", prefix = "tunEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tunEndpoint", }, @@ -28,6 +28,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index f94491026..a364c5801 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -76,29 +76,13 @@ func (d *Device) Release(ctx context.Context) { } } -// NICID returns the NIC ID of the device. -// -// Must only be called after the device has been attached to an endpoint. -func (d *Device) NICID() tcpip.NICID { - d.mu.RLock() - defer d.mu.RUnlock() - - if d.endpoint == nil { - panic("called NICID on a device that has not been attached") - } - - return d.endpoint.nicID -} - // SetIff services TUNSETIFF ioctl(2) request. -// -// Returns true if a new NIC was created; false if an existing one was attached. -func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) { +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { d.mu.Lock() defer d.mu.Unlock() if d.endpoint != nil { - return false, syserror.EINVAL + return syserror.EINVAL } // Input validations. @@ -106,7 +90,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) isTap := flags&linux.IFF_TAP != 0 supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { - return false, syserror.EINVAL + return syserror.EINVAL } prefix := "tun" @@ -119,18 +103,18 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) linkCaps |= stack.CapabilityResolutionRequired } - endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps) + endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return false, syserror.EINVAL + return syserror.EINVAL } d.endpoint = endpoint d.notifyHandle = d.endpoint.AddNotify(d) d.flags = flags - return created, nil + return nil } -func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) { +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { for { // 1. Try to attach to an existing NIC. if name != "" { @@ -138,19 +122,18 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE endpoint, ok := linkEP.(*tunEndpoint) if !ok { // Not a NIC created by tun device. - return nil, false, syserror.EOPNOTSUPP + return nil, syserror.EOPNOTSUPP } if !endpoint.TryIncRef() { // Race detected: NIC got deleted in between. continue } - return endpoint, false, nil + return endpoint, nil } } // 2. Creating a new NIC. id := tcpip.NICID(s.UniqueID()) - // TODO(gvisor.dev/1486): enable leak check for tunEndpoint. endpoint := &tunEndpoint{ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), stack: s, @@ -158,6 +141,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE name: name, isTap: prefix == "tap", } + endpoint.InitRefs() endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { endpoint.name = fmt.Sprintf("%s%d", prefix, id) @@ -167,12 +151,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE }) switch err { case nil: - return endpoint, true, nil + return endpoint, nil case tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: - return nil, false, syserror.EINVAL + return nil, syserror.EINVAL } } } @@ -280,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" { return nil, false } @@ -288,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. if info.Pkt.LinkHeader().View().IsEmpty() { - d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt) } vv.AppendView(info.Pkt.LinkHeader().View()) } diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index ee84c3d96..9b4602c1b 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -11,7 +11,6 @@ go_library( deps = [ "//pkg/gate", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], @@ -25,7 +24,6 @@ go_test( library = ":waitable", deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index b152a0f26..cf0077f43 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -24,7 +24,6 @@ package waitable import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -132,17 +131,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, err } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if !e.writeGate.Enter() { - return nil - } - - err := e.lower.WriteRawPacket(vv) - e.writeGate.Leave() - return err -} - // WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, // and waits for inflight ones to finish before returning. func (e *Endpoint) WaitWrite() { diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 94827fc56..cf7fb5126 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -18,7 +18,6 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -81,11 +80,6 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack. return pkts.Len(), nil } -func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - e.writeCount++ - return nil -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { panic("unimplemented") diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index c118a2929..9ebf31b78 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -7,13 +7,16 @@ go_test( size = "small", srcs = [ "ip_test.go", + "multicast_group_test.go", ], deps = [ "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index b40dde96b..8a6bcfc2c 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -30,5 +30,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7df77c66e..3d5c0d270 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -18,6 +18,7 @@ package arp import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -30,17 +31,15 @@ import ( const ( // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber - - // ProtocolAddress is the address expected by the ARP endpoint. - ProtocolAddress = tcpip.Address("arp") ) -var _ stack.AddressableEndpoint = (*endpoint)(nil) +// ARP endpoints need to implement stack.NetworkEndpoint because the stack +// considers the layer above the link-layer a network layer; the only +// facility provided by the stack to deliver packets to a layer above +// the link-layer is via stack.NetworkEndpoint.HandlePacket. var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { - stack.AddressableEndpointState - protocol *protocol // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. @@ -86,7 +85,7 @@ func (e *endpoint) Disable() { } // DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. -func (e *endpoint) DefaultTTL() uint8 { +func (*endpoint) DefaultTTL() uint8 { return 0 } @@ -99,29 +98,27 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.AddressableEndpointState.Cleanup() -} +func (*endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. -func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { +func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return ProtocolNumber } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } @@ -144,34 +141,43 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { - if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - // As per RFC 826, under Packet Reception: - // Swap hardware and protocol fields, putting the local hardware and - // protocol addresses in the sender fields. - // - // Send the packet to the (new) target hardware address on the same - // hardware on which the request was received. - origSender := h.HardwareAddressSender() - r.RemoteLinkAddress = tcpip.LinkAddress(origSender) respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) + respPkt.NetworkProtocolNumber = ProtocolNumber packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), origSender) - copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress()) + if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + origSender := h.HardwareAddressSender() + if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize)) + } + if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -199,15 +205,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { + stack *stack.Stack } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } func (p *protocol) MinimumPacketSize() int { return header.ARPSize } func (p *protocol) DefaultPrefixLen() int { return 0 } -func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - h := header.ARP(v) - return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress +func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { + return "", "" } func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { @@ -217,7 +223,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L linkAddrCache: linkAddrCache, nud: nud, } - e.AddressableEndpointState.Init(e) return e } @@ -227,26 +232,44 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - r := &stack.Route{ - NetProto: ProtocolNumber, - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + if len(remoteLinkAddr) == 0 { + remoteLinkAddr = header.EthernetBroadcastAddress } - if len(r.RemoteLinkAddress) == 0 { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + + nicID := nic.ID() + if len(localAddr) == 0 { + addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if err != nil { + return err + } + + if len(addr.Address) == 0 { + return tcpip.ErrNetworkUnreachable + } + + localAddr = addr.Address + } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return tcpip.ErrBadLocalAddress } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + pkt.NetworkProtocolNumber = ProtocolNumber h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), linkEP.LinkAddress()) - copy(h.ProtocolAddressSender(), localAddr) - copy(h.ProtocolAddressTarget(), addr) - - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -282,10 +305,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } // NewProtocol returns an ARP network protocol. -// -// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" -// address must be added to every NIC that should respond to ARP requests. See -// ProtocolAddress for more details. -func NewProtocol(*stack.Stack) stack.NetworkProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 626af975a..0fb373612 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,11 @@ func (t eventType) String() string { type eventInfo struct { eventType eventType nicID tcpip.NICID - addr tcpip.Address - linkAddr tcpip.LinkAddress - state stack.NeighborState + entry stack.NeighborEntry } func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) } // arpDispatcher implements NUDDispatcher to validate the dispatching of @@ -96,35 +95,29 @@ type arpDispatcher struct { var _ stack.NUDDispatcher = (*arpDispatcher)(nil) -func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryChanged, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryRemoved, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } @@ -132,7 +125,7 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { select { case got := <-d.C: - if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" { + if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { return fmt.Errorf("got invalid event (-got +want):\n%s", diff) } case <-ctx.Done(): @@ -207,9 +200,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { t.Fatalf("AddAddress for ipv4 failed: %v", err) } } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("AddAddress for arp failed: %v", err) - } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -329,9 +319,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { copy(h.HardwareAddressSender(), test.senderLinkAddr) copy(h.ProtocolAddressSender(), test.senderAddr) copy(h.ProtocolAddressTarget(), test.targetAddr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) if !test.isValid { // No packets should be sent after receiving an invalid ARP request. @@ -373,9 +363,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { wantEvent := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: test.senderAddr, - linkAddr: tcpip.LinkAddress(test.senderLinkAddr), - state: stack.Stale, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), + State: stack.Stale, + }, } if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { t.Fatal(err) @@ -404,9 +396,6 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) } - if got, want := neigh.LocalAddr, stackAddr; got != want { - t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want) - } if got, want := neigh.State, stack.Stale; got != want { t.Errorf("got neighbor State = %s, want = %s", got, want) } @@ -423,43 +412,168 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.LinkEndpoint + + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (*testInterface) Promiscuous() bool { + return false +} + +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + } + r.ResolveWith(remoteLinkAddr) + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + + testAddr := tcpip.Address([]byte{1, 2, 3, 4}) + tests := []struct { name string + nicAddr tcpip.Address + localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Multicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Unicast with no local address available", remoteLinkAddr: remoteLinkAddr, - expectLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, }, { - name: "Multicast", + name: "Multicast with no local address available", remoteLinkAddr: "", - expectLinkAddr: header.EthernetBroadcastAddress, + expectedErr: tcpip.ErrNetworkUnreachable, }, } for _, test := range tests { - p := arp.NewProtocol(nil) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + }) + p := s.NetworkProtocolInstance(arp.ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err) - } + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + } + } - if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) - } + // We pass a test network interface to LinkAddressRequest with the same + // NIC ID and link endpoint used by the NIC we created earlier so that we + // can mock a link address request and observe the packets sent to the + // link endpoint even though the stack uses the real NIC to validate the + // local address. + if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) + } + + rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { + t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr) + } + }) } } diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 47fb63290..429af69ee 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -18,7 +18,6 @@ go_template_instance( go_library( name = "fragmentation", srcs = [ - "frag_heap.go", "fragmentation.go", "reassembler.go", "reassembler_list.go", @@ -38,7 +37,6 @@ go_test( name = "fragmentation_test", size = "small", srcs = [ - "frag_heap_test.go", "fragmentation_test.go", "reassembler_test.go", ], @@ -47,6 +45,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go deleted file mode 100644 index 0b570d25a..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "fmt" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -type fragment struct { - offset uint16 - vv buffer.VectorisedView -} - -type fragHeap []fragment - -func (h *fragHeap) Len() int { - return len(*h) -} - -func (h *fragHeap) Less(i, j int) bool { - return (*h)[i].offset < (*h)[j].offset -} - -func (h *fragHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *fragHeap) Push(x interface{}) { - *h = append(*h, x.(fragment)) -} - -func (h *fragHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[:n-1] - return x -} - -// reassamble empties the heap and returns a VectorisedView -// containing a reassambled version of the fragments inside the heap. -func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { - curr := heap.Pop(h).(fragment) - views := curr.vv.Views() - size := curr.vv.Size() - - if curr.offset != 0 { - return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) - } - - for h.Len() > 0 { - curr := heap.Pop(h).(fragment) - if int(curr.offset) < size { - curr.vv.TrimFront(size - int(curr.offset)) - } else if int(curr.offset) > size { - return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) - } - size += curr.vv.Size() - views = append(views, curr.vv.Views()...) - } - return buffer.NewVectorisedView(size, views), nil -} diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go deleted file mode 100644 index 9ececcb9f..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -var reassambleTestCases = []struct { - comment string - in []fragment - want buffer.VectorisedView -}{ - { - comment: "Non-overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Non-overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Duplicated packets", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(1, "0"), - }, - { - comment: "Overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(2, "01")}, - {offset: 1, vv: vv(2, "12")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(2, "12")}, - {offset: 0, vv: vv(2, "01")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping subset in-order", - in: []fragment{ - {offset: 0, vv: vv(3, "012")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(3, "012"), - }, - { - comment: "Overlapping subset out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(3, "012")}, - }, - want: vv(3, "012"), - }, -} - -func TestReassamble(t *testing.T) { - for _, c := range reassambleTestCases { - t.Run(c.comment, func(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, err := h.reassemble() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, c.want) { - t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) - } - }) - } -} - -func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when the first packet had offset != 0") - } -} - -func TestReassambleFailsForHoles(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) - heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when there was a hole in the packet") - } -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index ed502a473..1af87d713 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -46,9 +46,17 @@ const ( ) var ( - // ErrInvalidArgs indicates to the caller that that an invalid argument was + // ErrInvalidArgs indicates to the caller that an invalid argument was // provided. ErrInvalidArgs = errors.New("invalid args") + + // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps + // with another one. + ErrFragmentOverlap = errors.New("overlapping fragments") + + // ErrFragmentConflict indicates that, during reassembly, some fragments are + // in conflict with one another. + ErrFragmentConflict = errors.New("conflicting fragments") ) // FragmentID is the identifier for a fragment. @@ -71,16 +79,25 @@ type FragmentID struct { // Fragmentation is the main structure that other modules // of the stack should use to implement IP Fragmentation. type Fragmentation struct { - mu sync.Mutex - highLimit int - lowLimit int - reassemblers map[FragmentID]*reassembler - rList reassemblerList - size int - timeout time.Duration - blockSize uint16 - clock tcpip.Clock - releaseJob *tcpip.Job + mu sync.Mutex + highLimit int + lowLimit int + reassemblers map[FragmentID]*reassembler + rList reassemblerList + size int + timeout time.Duration + blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job + timeoutHandler TimeoutHandler +} + +// TimeoutHandler is consulted if a packet reassembly has timed out. +type TimeoutHandler interface { + // OnReassemblyTimeout will be called with the first fragment (or nil, if the + // first fragment has not been received) of a packet whose reassembly has + // timed out. + OnReassemblyTimeout(pkt *stack.PacketBuffer) } // NewFragmentation creates a new Fragmentation. @@ -97,7 +114,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -111,12 +128,13 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea } f := &Fragmentation{ - reassemblers: make(map[FragmentID]*reassembler), - highLimit: highMemoryLimit, - lowLimit: lowMemoryLimit, - timeout: reassemblingTimeout, - blockSize: blockSize, - clock: clock, + reassemblers: make(map[FragmentID]*reassembler), + highLimit: highMemoryLimit, + lowLimit: lowMemoryLimit, + timeout: reassemblingTimeout, + blockSize: blockSize, + clock: clock, + timeoutHandler: timeoutHandler, } f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) @@ -137,7 +155,7 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // to be given here outside of the FragmentID struct because IPv6 should not use // the protocol to identify a fragment. func (f *Fragmentation) Process( - id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) ( + id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( buffer.VectorisedView, uint8, bool, error) { if first > last { return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) @@ -152,10 +170,9 @@ func (f *Fragmentation) Process( return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } - if l := vv.Size(); l < int(fragmentSize) { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + if l := pkt.Data.Size(); l != int(fragmentSize) { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } - vv.CapLength(int(fragmentSize)) f.mu.Lock() r, ok := f.reassemblers[id] @@ -173,19 +190,19 @@ func (f *Fragmentation) Process( } f.mu.Unlock() - res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) + res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() - f.release(r) + f.release(r, false /* timedOut */) f.mu.Unlock() return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() f.size += consumed if done { - f.release(r) + f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. @@ -195,14 +212,14 @@ func (f *Fragmentation) Process( if tail == nil { break } - f.release(tail) + f.release(tail, false /* timedOut */) } } f.mu.Unlock() return res, firstFragmentProto, done, nil } -func (f *Fragmentation) release(r *reassembler) { +func (f *Fragmentation) release(r *reassembler, timedOut bool) { // Before releasing a fragment we need to check if r is already marked as done. // Otherwise, we would delete it twice. if r.checkDoneOrMark() { @@ -216,6 +233,10 @@ func (f *Fragmentation) release(r *reassembler) { log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) f.size = 0 } + + if h := f.timeoutHandler; timedOut && h != nil { + h.OnReassemblyTimeout(r.pkt) + } } // releaseReassemblersLocked releases already-expired reassemblers, then @@ -238,31 +259,31 @@ func (f *Fragmentation) releaseReassemblersLocked() { break } // If the oldest reassembler has already expired, release it. - f.release(r) + f.release(r, true /* timedOut*/) } } // PacketFragmenter is the book-keeping struct for packet fragmentation. type PacketFragmenter struct { - transportHeader buffer.View - data buffer.VectorisedView - reserve int - innerMTU int - fragmentCount int - currentFragment int - fragmentOffset int + transportHeader buffer.View + data buffer.VectorisedView + reserve int + fragmentPayloadLen int + fragmentCount int + currentFragment int + fragmentOffset int } // MakePacketFragmenter prepares the struct needed for packet fragmentation. // // pkt is the packet to be fragmented. // -// innerMTU is the maximum number of bytes of fragmentable data a fragment can +// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can // have. // // reserve is the number of bytes that should be reserved for the headers in // each generated fragment. -func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter { +func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter { // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be // repeated in each fragment. However we do not currently support any header // of that kind yet, so the following computation is valid for both IPv4 and @@ -273,13 +294,13 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) Pa var fragmentableData buffer.VectorisedView fragmentableData.AppendView(pkt.TransportHeader().View()) fragmentableData.Append(pkt.Data) - fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU + fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen return PacketFragmenter{ - data: fragmentableData, - reserve: reserve, - innerMTU: innerMTU, - fragmentCount: fragmentCount, + data: fragmentableData, + reserve: reserve, + fragmentPayloadLen: int(fragmentPayloadLen), + fragmentCount: int(fragmentCount), } } @@ -302,7 +323,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, }) // Copy data for the fragment. - copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU) + copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) offset := pf.fragmentOffset pf.fragmentOffset += copied diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index d3c7d7f92..3a79688a8 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // reassembleTimeout is dummy timeout used for testing, where the clock never @@ -40,13 +41,19 @@ func vv(size int, pieces ...string) buffer.VectorisedView { return buffer.NewVectorisedView(size, views) } +func pkt(size int, pieces ...string) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv(size, pieces...), + }) +} + type processInput struct { id FragmentID first uint16 last uint16 more bool proto uint8 - vv buffer.VectorisedView + pkt *stack.PacketBuffer } type processOutput struct { @@ -63,8 +70,8 @@ var processTestCases = []struct { { comment: "One ID", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -74,8 +81,8 @@ var processTestCases = []struct { { comment: "Next Header protocol mismatch", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -85,10 +92,10 @@ var processTestCases = []struct { { comment: "Two IDs", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, + {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -102,17 +109,17 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) + t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, err) } if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)", - in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView()) + t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)", + in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView()) } if done != c.out[i].done { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", @@ -236,11 +243,11 @@ func TestReassemblingTimeout(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock) + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) for _, event := range test.events { clock.Advance(event.clockAdvance) if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data)) + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -257,17 +264,17 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -281,11 +288,11 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) got := f.size want := 1 @@ -327,6 +334,7 @@ func TestErrors(t *testing.T) { last: 3, more: true, data: "012", + err: ErrInvalidArgs, }, { name: "exact block size with more and too little data", @@ -376,8 +384,8 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } @@ -403,14 +411,14 @@ func TestPacketFragmenter(t *testing.T) { tests := []struct { name string - innerMTU int + fragmentPayloadLen uint32 transportHeaderLen int payloadSize int wantFragments []fragmentInfo }{ { name: "Packet exactly fits in MTU", - innerMTU: 1280, + fragmentPayloadLen: 1280, transportHeaderLen: 0, payloadSize: 1280, wantFragments: []fragmentInfo{ @@ -419,7 +427,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet exactly does not fit in MTU", - innerMTU: 1000, + fragmentPayloadLen: 1000, transportHeaderLen: 0, payloadSize: 1001, wantFragments: []fragmentInfo{ @@ -429,7 +437,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a transport header", - innerMTU: 560, + fragmentPayloadLen: 560, transportHeaderLen: 40, payloadSize: 560, wantFragments: []fragmentInfo{ @@ -439,7 +447,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a huge transport header", - innerMTU: 500, + fragmentPayloadLen: 500, transportHeaderLen: 1300, payloadSize: 500, wantFragments: []fragmentInfo{ @@ -458,7 +466,7 @@ func TestPacketFragmenter(t *testing.T) { originalPayload.AppendView(pkt.TransportHeader().View()) originalPayload.Append(pkt.Data) var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.innerMTU, reserve) + pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) for i := 0; ; i++ { fragPkt, offset, copied, more := pf.BuildNextFragment() wantFragment := test.wantFragments[i] @@ -474,8 +482,8 @@ func TestPacketFragmenter(t *testing.T) { if more != wantFragment.more { t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) } - if got := fragPkt.Size(); got > test.innerMTU { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU) + if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) } if got := fragPkt.AvailableHeaderBytes(); got != reserve { t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) @@ -497,3 +505,126 @@ func TestPacketFragmenter(t *testing.T) { }) } } + +type testTimeoutHandler struct { + pkt *stack.PacketBuffer +} + +func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + h.pkt = pkt +} + +func TestTimeoutHandler(t *testing.T) { + const ( + proto = 99 + ) + + pk1 := pkt(1, "1") + pk2 := pkt(1, "2") + + type processParam struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + } + + tests := []struct { + name string + params []processParam + wantError bool + wantPkt *stack.PacketBuffer + }{ + { + name: "onTimeout runs", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: pk1, + }, + { + name: "no first fragment", + params: []processParam{ + { + first: 1, + last: 1, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: nil, + }, + { + name: "second pkt is ignored", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + { + first: 0, + last: 0, + more: true, + pkt: pk2, + }, + }, + wantError: false, + wantPkt: pk1, + }, + { + name: "invalid args - first is greater than last", + params: []processParam{ + { + first: 1, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: true, + wantPkt: nil, + }, + } + + id := FragmentID{ID: 0} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + handler := &testTimeoutHandler{pkt: nil} + + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) + + for _, p := range test.params { + if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { + t.Errorf("f.Process error = %s", err) + } + } + if !test.wantError { + r, ok := f.reassemblers[id] + if !ok { + t.Fatal("Reassembler not found") + } + f.release(r, true) + } + switch { + case handler.pkt != nil && test.wantPkt == nil: + t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) + case handler.pkt == nil && test.wantPkt != nil: + t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) + case handler.pkt != nil && test.wantPkt != nil: + if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { + t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) + } + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9bb051a30..9b20bb1d8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -15,19 +15,21 @@ package fragmentation import ( - "container/heap" - "fmt" "math" + "sort" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) type hole struct { - first uint16 - last uint16 - deleted bool + first uint16 + last uint16 + filled bool + final bool + data buffer.View } type reassembler struct { @@ -37,83 +39,139 @@ type reassembler struct { proto uint8 mu sync.Mutex holes []hole - deleted int - heap fragHeap + filled int done bool creationTime int64 + pkt *stack.PacketBuffer } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, - holes: make([]hole, 0, 16), - heap: make(fragHeap, 0, 8), creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ - first: 0, - last: math.MaxUint16, - deleted: false}) + first: 0, + last: math.MaxUint16, + filled: false, + final: true, + }) return r } -// updateHoles updates the list of holes for an incoming fragment and -// returns true iff the fragment filled at least part of an existing hole. -func (r *reassembler) updateHoles(first, last uint16, more bool) bool { - used := false - for i := range r.holes { - if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first { - continue - } - used = true - r.deleted++ - r.holes[i].deleted = true - if first > r.holes[i].first { - r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false}) - } - if last < r.holes[i].last && more { - r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false}) - } - } - return used -} - -func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() - consumed := 0 if r.done { // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, 0, false, consumed, nil - } - // For IPv6, it is possible to have different Protocol values between - // fragments of a packet (because, unlike IPv4, the Protocol is not used to - // identify a fragment). In this case, only the Protocol of the first - // fragment must be used as per RFC 8200 Section 4.5. - // - // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded - // here (instead of just the protocol) because most IP options should be - // derived from the first fragment. - if first == 0 { - r.proto = proto + return buffer.VectorisedView{}, 0, false, 0, nil } - if r.updateHoles(first, last, more) { - // We store the incoming packet only if it filled some holes. - heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) - consumed = vv.Size() + + var holeFound bool + var consumed int + for i := range r.holes { + currentHole := &r.holes[i] + + if last < currentHole.first || currentHole.last < first { + continue + } + // For IPv6, overlaps with an existing fragment are explicitly forbidden by + // RFC 8200 section 4.5: + // If any of the fragments being reassembled overlap with any other + // fragments being reassembled for the same packet, reassembly of that + // packet must be abandoned and all the fragments that have been received + // for that packet must be discarded, and no ICMP error messages should be + // sent. + // + // It is not explicitly forbidden for IPv4, but to keep parity with Linux we + // disallow it as well: + // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 + if first < currentHole.first || currentHole.last < last { + // Incoming fragment only partially fits in the free hole. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap + } + if !more { + if !currentHole.final || currentHole.filled && currentHole.last != last { + // We have another final fragment, which does not perfectly overlap. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + } + } + + holeFound = true + if currentHole.filled { + // Incoming fragment is a duplicate. + continue + } + + // We are populating the current hole with the payload and creating a new + // hole for any unfilled ranges on either end. + if first > currentHole.first { + r.holes = append(r.holes, hole{ + first: currentHole.first, + last: first - 1, + filled: false, + final: false, + }) + } + if last < currentHole.last && more { + r.holes = append(r.holes, hole{ + first: last + 1, + last: currentHole.last, + filled: false, + final: currentHole.final, + }) + currentHole.final = false + } + v := pkt.Data.ToOwnedView() + consumed = v.Size() r.size += consumed + // Update the current hole to precisely match the incoming fragment. + r.holes[i] = hole{ + first: first, + last: last, + filled: true, + final: currentHole.final, + data: v, + } + r.filled++ + // For IPv6, it is possible to have different Protocol values between + // fragments of a packet (because, unlike IPv4, the Protocol is not used to + // identify a fragment). In this case, only the Protocol of the first + // fragment must be used as per RFC 8200 Section 4.5. + // + // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP + // options received in the first fragment should be used - and they should + // override options from following fragments. + if first == 0 { + r.pkt = pkt + r.proto = proto + } + + break + } + if !holeFound { + // Incoming fragment is beyond end. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict } - // Check if all the holes have been deleted and we are ready to reassamble. - if r.deleted < len(r.holes) { + + // Check if all the holes have been filled and we are ready to reassemble. + if r.filled < len(r.holes) { return buffer.VectorisedView{}, 0, false, consumed, nil } - res, err := r.heap.reassemble() - if err != nil { - return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err) + + sort.Slice(r.holes, func(i, j int) bool { + return r.holes[i].first < r.holes[j].first + }) + + var size int + views := make([]buffer.View, 0, len(r.holes)) + for _, hole := range r.holes { + views = append(views, hole.data) + size += hole.data.Size() } - return res, r.proto, true, consumed, nil + return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil } func (r *reassembler) checkDoneOrMark() bool { diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index a0a04a027..2ff03eeeb 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -16,92 +16,175 @@ package fragmentation import ( "math" - "reflect" "testing" + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -type updateHolesInput struct { - first uint16 - last uint16 - more bool +type processParams struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + wantDone bool + wantError error } -var holesTestCases = []struct { - comment string - in []updateHolesInput - want []hole -}{ - { - comment: "No fragments. Expected holes: {[0 -> inf]}.", - in: []updateHolesInput{}, - want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, - }, - { - comment: "One fragment at beginning. Expected holes: {[2, inf]}.", - in: []updateHolesInput{{first: 0, last: 1, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: false}, +func TestReassemblerProcess(t *testing.T) { + const proto = 99 + + v := func(size int) buffer.View { + payload := buffer.NewView(size) + for i := 1; i < size; i++ { + payload[i] = uint8(i) * 3 + } + return payload + } + + pkt := func(size int) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v(size).ToVectorisedView(), + }) + } + + var tests = []struct { + name string + params []processParams + want []hole + }{ + { + name: "No fragments", + params: nil, + want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, }, - }, - { - comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", - in: []updateHolesInput{{first: 1, last: 2, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - {first: 3, last: math.MaxUint16, deleted: false}, + { + name: "One fragment at beginning", + params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: math.MaxUint16, filled: false, final: true}, + }, }, - }, - { - comment: "One fragment at the end. Expected holes: {[0, 0]}.", - in: []updateHolesInput{{first: 1, last: 2, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, + { + name: "One fragment in the middle", + params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true, final: false, data: v(2)}, + {first: 0, last: 0, filled: false, final: false}, + {first: 3, last: math.MaxUint16, filled: false, final: true}, + }, }, - }, - { - comment: "One fragment completing a packet. Expected holes: {}.", - in: []updateHolesInput{{first: 0, last: 1, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, + { + name: "One fragment at the end", + params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true, final: true, data: v(2)}, + {first: 0, last: 0, filled: false}, + }, }, - }, - { - comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 1, more: true}, - {first: 2, last: 3, more: false}, + { + name: "One fragment completing a packet", + params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true, final: true, data: v(2)}, + }, }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: true}, + { + name: "Two fragments completing a packet", + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, + }, }, - }, - { - comment: "Two overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 2, more: true}, - {first: 2, last: 3, more: false}, + { + name: "Two fragments completing a packet with a duplicate", + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, + }, }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 3, last: math.MaxUint16, deleted: true}, + { + name: "Two fragments completing a packet with a partial duplicate", + params: []processParams{ + {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, + {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 3, filled: true, final: false, data: v(4)}, + {first: 4, last: 5, filled: true, final: true, data: v(2)}, + }, }, - }, -} + { + name: "Two overlapping fragments", + params: []processParams{ + {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, + {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, + }, + want: []hole{ + {first: 0, last: 10, filled: true, final: false, data: v(11)}, + {first: 11, last: math.MaxUint16, filled: false, final: true}, + }, + }, + { + name: "Two final fragments with different ends", + params: []processParams{ + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 10, last: 14, filled: true, final: true, data: v(5)}, + {first: 0, last: 9, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + }, + want: []hole{ + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate, with different ends", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, + }, + }, + } -func TestUpdateHoles(t *testing.T) { - for _, c := range holesTestCases { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - for _, i := range c.in { - r.updateHoles(i.first, i.last, i.more) - } - if !reflect.DeepEqual(r.holes, c.want) { - t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + for _, param := range test.params { + _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + if done != param.wantDone || err != param.wantError { + t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) + } + } + if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD new file mode 100644 index 000000000..ca1247c1e --- /dev/null +++ b/pkg/tcpip/network/ip/BUILD @@ -0,0 +1,26 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ip", + srcs = ["generic_multicast_protocol.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + ], +) + +go_test( + name = "ip_test", + size = "small", + srcs = ["generic_multicast_protocol_test.go"], + deps = [ + ":ip", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/faketime", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go new file mode 100644 index 000000000..f85c5ff9d --- /dev/null +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -0,0 +1,671 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ip holds IPv4/IPv6 common utilities. +package ip + +import ( + "fmt" + "math/rand" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// hostState is the state a host may be in for a multicast group. +type hostState int + +// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 +// (RFC 2710 section 5). Even though the states are generic across both IGMPv2 +// and MLDv1, IGMPv2 terminology will be used. +// +// ______________receive query______________ +// | | +// | _____send or receive report_____ | +// | | | | +// V | V | +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | +// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ +// | ^ | ^ | ^ | ^ +// | | | | | | | | +// ---------- ------- ---------- ------------- +// initialize new send inital fail to send send or receive +// group membership report delayed report report +// +// Not shown in the diagram above, but any state may transition into the non +// member state when a group is left. +const ( + // nonMember is the "'Non-Member' state, when the host does not belong to the + // group on the interface. This is the initial state for all memberships on + // all network interfaces; it requires no storage in the host." + // + // 'Non-Listener' is the MLDv1 term used to describe this state. + // + // This state is used to keep track of groups that have been joined locally, + // but without advertising the membership to the network. + nonMember hostState = iota + + // pendingMember is a newly joined member that is waiting to successfully send + // the initial set of reports. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the initial report needs to be sent. + // + // MAY NOT transition to the idle member state from this state. + pendingMember + + // delayingMember is the "'Delaying Member' state, when the host belongs to + // the group on the interface and has a report delay timer running for that + // membership." + // + // 'Delaying Listener' is the MLDv1 term used to describe this state. + delayingMember + + // queuedDelayingMember is a delayingMember that failed to send a report after + // its delayed report timer fired. Hosts in this state are waiting to attempt + // retransmission of the delayed report. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the delayed report needs to be sent. + // + // May transition to idle member if a report is received for a group. + queuedDelayingMember + + // idleMember is the "Idle Member" state, when the host belongs to the group + // on the interface and does not have a report delay timer running for that + // membership. + // + // 'Idle Listener' is the MLDv1 term used to describe this state. + idleMember +) + +func (s hostState) isDelayingMember() bool { + switch s { + case nonMember, pendingMember, idleMember: + return false + case delayingMember, queuedDelayingMember: + return true + default: + panic(fmt.Sprintf("unrecognized host state = %d", s)) + } +} + +// multicastGroupState holds the Generic Multicast Protocol state for a +// multicast group. +type multicastGroupState struct { + // joins is the number of times the group has been joined. + joins uint64 + + // state holds the host's state for the group. + state hostState + + // lastToSendReport is true if we sent the last report for the group. It is + // used to track whether there are other hosts on the subnet that are also + // members of the group. + // + // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page + // 8 for MLDv1. + lastToSendReport bool + + // delayedReportJob is used to delay sending responses to membership report + // messages in order to reduce duplicate reports from multiple hosts on the + // interface. + // + // Must not be nil. + delayedReportJob *tcpip.Job +} + +// GenericMulticastProtocolOptions holds options for the generic multicast +// protocol. +type GenericMulticastProtocolOptions struct { + // Enabled indicates whether the generic multicast protocol will be + // performed. + // + // When enabled, the protocol may transmit report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // packets. + // + // When disabled, the protocol will still keep track of locally joined groups, + // it just won't transmit and handle packets, or update groups' state. + Enabled bool + + // Rand is the source of random numbers. + Rand *rand.Rand + + // Clock is the clock used to create timers. + Clock tcpip.Clock + + // Protocol is the implementation of the variant of multicast group protocol + // in use. + Protocol MulticastGroupProtocol + + // MaxUnsolicitedReportDelay is the maximum amount of time to wait between + // transmitting unsolicited reports. + // + // Unsolicited reports are transmitted when a group is newly joined. + MaxUnsolicitedReportDelay time.Duration + + // AllNodesAddress is a multicast address that all nodes on a network should + // be a member of. + // + // This address will not have the generic multicast protocol performed on it; + // it will be left in the non member/listener state, and packets will never + // be sent for it. + AllNodesAddress tcpip.Address +} + +// MulticastGroupProtocol is a multicast group protocol whose core state machine +// can be represented by GenericMulticastProtocolState. +type MulticastGroupProtocol interface { + // SendReport sends a multicast report for the specified group address. + // + // Returns false if the caller should queue the report to be sent later. Note, + // returning false does not mean that the receiver hit an error. + SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) + + // SendLeave sends a multicast leave for the specified group address. + SendLeave(groupAddress tcpip.Address) *tcpip.Error +} + +// GenericMulticastProtocolState is the per interface generic multicast protocol +// state. +// +// There is actually no protocol named "Generic Multicast Protocol". Instead, +// the term used to refer to a generic multicast protocol that applies to both +// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state +// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. +// +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// +// GenericMulticastProtocolState.Init MUST be called before calling any of +// the methods on GenericMulticastProtocolState. +type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + + opts GenericMulticastProtocolOptions + + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState + + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex +} + +// Init initializes the Generic Multicast Protocol state. +// +// Must only be called once for the lifetime of g; Init will panic if it is +// called twice. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +// +// Note: the methods on opts.Protocol will always be called while protocolMU is +// held. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + + *g = GenericMulticastProtocolState{ + opts: opts, + memberships: make(map[tcpip.Address]multicastGroupState), + protocolMU: protocolMU, + } +} + +// MakeAllNonMemberLocked transitions all groups to the non-member state. +// +// The groups will still be considered joined locally. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { + if !g.opts.Enabled { + return + } + + for groupAddress, info := range g.memberships { + g.transitionToNonMemberLocked(groupAddress, &info) + g.memberships[groupAddress] = info + } +} + +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. +// +// Must only be called after calling MakeAllNonMember as a group should not be +// initialized while it is not in the non-member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { + if !g.opts.Enabled { + return + } + + for groupAddress, info := range g.memberships { + g.initializeNewMemberLocked(groupAddress, &info) + g.memberships[groupAddress] = info + } +} + +// SendQueuedReportsLocked attempts to send reports for groups that failed to +// send reports during their last attempt. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { + for groupAddress, info := range g.memberships { + switch info.state { + case nonMember, delayingMember, idleMember: + case pendingMember: + // pendingMembers failed to send their initial unsolicited report so try + // to send the report and queue the extra unsolicited reports. + g.maybeSendInitialReportLocked(groupAddress, &info) + case queuedDelayingMember: + // queuedDelayingMembers failed to send their delayed reports so try to + // send the report and transition them to the idle state. + g.maybeSendDelayedReportLocked(groupAddress, &info) + default: + panic(fmt.Sprintf("unrecognized host state = %d", info.state)) + } + g.memberships[groupAddress] = info + } +} + +// JoinGroupLocked handles joining a new group. +// +// If dontInitialize is true, the group will be not be initialized and will be +// left in the non-member state - no packets will be sent for it until it is +// initialized via InitializeGroups. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) { + if info, ok := g.memberships[groupAddress]; ok { + // The group has already been joined. + info.joins++ + g.memberships[groupAddress] = info + return + } + + info := multicastGroupState{ + // Since we just joined the group, its count is 1. + joins: 1, + // The state will be updated below, if required. + state: nonMember, + lastToSendReport: false, + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + info, ok := g.memberships[groupAddress] + if !ok { + panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) + } + + g.maybeSendDelayedReportLocked(groupAddress, &info) + g.memberships[groupAddress] = info + }), + } + + if !dontInitialize && g.opts.Enabled { + g.initializeNewMemberLocked(groupAddress, &info) + } + + g.memberships[groupAddress] = info +} + +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] + return ok +} + +// LeaveGroupLocked handles leaving the group. +// +// Returns false if the group is not currently joined. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] + if !ok { + return false + } + + if info.joins == 0 { + panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress)) + } + info.joins-- + if info.joins != 0 { + // If we still have outstanding joins, then do nothing further. + g.memberships[groupAddress] = info + return true + } + + g.transitionToNonMemberLocked(groupAddress, &info) + delete(g.memberships, groupAddress) + return true +} + +// HandleQueryLocked handles a query message with the specified maximum response +// time. +// +// If the group address is unspecified, then reports will be scheduled for all +// joined groups. +// +// Report(s) will be scheduled to be sent after a random duration between 0 and +// the maximum response time. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { + if !g.opts.Enabled { + return + } + + // As per RFC 2236 section 2.4 (for IGMPv2), + // + // In a Membership Query message, the group address field is set to zero + // when sending a General Query, and set to the group address being + // queried when sending a Group-Specific Query. + // + // As per RFC 2710 section 3.6 (for MLDv1), + // + // In a Query message, the Multicast Address field is set to zero when + // sending a General Query, and set to a specific IPv6 multicast address + // when sending a Multicast-Address-Specific Query. + if groupAddress.Unspecified() { + // This is a general query as the group address is unspecified. + for groupAddress, info := range g.memberships { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.memberships[groupAddress] = info + } + } else if info, ok := g.memberships[groupAddress]; ok { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.memberships[groupAddress] = info + } +} + +// HandleReportLocked handles a report message. +// +// If the report is for a joined group, any active delayed report will be +// cancelled and the host state for the group transitions to idle. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { + if !g.opts.Enabled { + return + } + + // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), + // + // If the host receives another host's Report (version 1 or 2) while it has + // a timer running, it stops its timer for the specified group and does not + // send a Report + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // If a node receives another node's Report from an interface for a + // multicast address while it has a timer running for that same address + // on that interface, it stops its timer and does not send a Report for + // that address, thus suppressing duplicate reports on the link. + if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { + info.delayedReportJob.Cancel() + info.lastToSendReport = false + info.state = idleMember + g.memberships[groupAddress] = info + } +} + +// initializeNewMemberLocked initializes a new group membership. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != nonMember { + panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) + } + + info.lastToSendReport = false + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + info.state = idleMember + return + } + + info.state = pendingMember + g.maybeSendInitialReportLocked(groupAddress, info) +} + +// maybeSendInitialReportLocked attempts to start transmission of the initial +// set of reports after newly joining a group. +// +// Host must be in pending member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != pendingMember { + panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) + } + + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a host joins a multicast group, it should immediately transmit an + // unsolicited Version 2 Membership Report for that group" ... "it is + // recommended that it be repeated". + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node starts listening to a multicast address on an interface, + // it should immediately transmit an unsolicited Report for that address + // on that interface, in case it is the first listener on the link. To + // cover the possibility of the initial Report being lost or damaged, it + // is recommended that it be repeated once or twice after short delays + // [Unsolicited Report Interval]. + // + // TODO(gvisor.dev/issue/4901): Support a configurable number of initial + // unsolicited reports. + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + } +} + +// maybeSendDelayedReportLocked attempts to send the delayed report. +// +// Host must be in pending, delaying or queued delaying member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if !info.state.isDelayingMember() { + panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state)) + } + + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + info.state = idleMember + } else { + info.state = queuedDelayingMember + } +} + +// maybeSendLeave attempts to send a leave message. +func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { + if !g.opts.Enabled || !lastToSendReport { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // Okay to ignore the error here as if packet write failed, the multicast + // routers will eventually drop our membership anyways. If the interface is + // being disabled or removed, the generic multicast protocol's should be + // cleared eventually. + // + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a router receives a Report, it adds the group being reported to + // the list of multicast group memberships on the network on which it + // received the Report and sets the timer for the membership to the + // [Group Membership Interval]. Repeated Reports refresh the timer. If + // no Reports are received for a particular group before this timer has + // expired, the router assumes that the group has no local members and + // that it need not forward remotely-originated multicasts for that + // group onto the attached network. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // When a router receives a Report from a link, if the reported address + // is not already present in the router's list of multicast address + // having listeners on that link, the reported address is added to the + // list, its timer is set to [Multicast Listener Interval], and its + // appearance is made known to the router's multicast routing component. + // If a Report is received for a multicast address that is already + // present in the router's list, the timer for that address is reset to + // [Multicast Listener Interval]. If an address's timer expires, it is + // assumed that there are no longer any listeners for that address + // present on the link, so it is deleted from the list and its + // disappearance is made known to the multicast routing component. + // + // The requirement to send a leave message is also optional (it MAY be + // skipped): + // + // As per RFC 2236 section 6 page 8 (for IGMPv2), + // + // "send leave" for the group on the interface. If the interface + // state says the Querier is running IGMPv1, this action SHOULD be + // skipped. If the flag saying we were the last host to report is + // cleared, this action MAY be skipped. The Leave Message is sent to + // the ALL-ROUTERS group (224.0.0.2). + // + // As per RFC 2710 section 5 page 8 (for MLDv1), + // + // "send done" for the address on the interface. If the flag saying + // we were the last node to report is cleared, this action MAY be + // skipped. The Done message is sent to the link-scope all-routers + // address (FF02::2). + _ = g.opts.Protocol.SendLeave(groupAddress) +} + +// transitionToNonMemberLocked transitions the given multicast group the the +// non-member/listener state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state == nonMember { + return + } + + info.delayedReportJob.Cancel() + g.maybeSendLeave(groupAddress, info.lastToSendReport) + info.lastToSendReport = false + info.state = nonMember +} + +// setDelayTimerForAddressRLocked sets timer to send a delay report. +// +// Precondition: g.protocolMU MUST be read locked. +func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { + if info.state == nonMember { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // As per RFC 2236 section 3 page 3 (for IGMPv2), + // + // If a timer for the group is already unning, it is reset to the random + // value only if the requested Max Response Time is less than the remaining + // value of the running timer. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // If a timer for any address is already running, it is reset to the new + // random value only if the requested Maximum Response Delay is less than + // the remaining value of the running timer. + if info.state == delayingMember { + // TODO: Reset the timer if time remaining is greater than maxResponseTime. + return + } + + info.state = delayingMember + info.delayedReportJob.Cancel() + info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) +} + +// calculateDelayTimerDuration returns a random time between (0, maxRespTime]. +func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration { + // As per RFC 2236 section 3 page 3 (for IGMPv2), + // + // When a host receives a Group-Specific Query, it sets a delay timer to a + // random value selected from the range (0, Max Response Time]... + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node receives a Multicast-Address-Specific Query, if it is + // listening to the queried Multicast Address on the interface from + // which the Query was received, it sets a delay timer for that address + // to a random value selected from the range [0, Maximum Response Delay], + // as above. + if maxRespTime == 0 { + return 0 + } + return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime))) +} diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go new file mode 100644 index 000000000..95040515c --- /dev/null +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -0,0 +1,884 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "math/rand" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" +) + +const ( + addr1 = tcpip.Address("\x01") + addr2 = tcpip.Address("\x02") + addr3 = tcpip.Address("\x03") + addr4 = tcpip.Address("\x04") + + maxUnsolicitedReportDelay = time.Second +) + +var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) + +type mockMulticastGroupProtocol struct { + t *testing.T + + mu sync.RWMutex + + // Must only be accessed with mu held. + sendReportGroupAddrCount map[tcpip.Address]int + + // Must only be accessed with mu held. + sendLeaveGroupAddrCount map[tcpip.Address]int + + // Must only be accessed with mu held. + makeQueuePackets bool +} + +func (m *mockMulticastGroupProtocol) init() { + m.mu.Lock() + defer m.mu.Unlock() + m.initLocked() +} + +func (m *mockMulticastGroupProtocol) initLocked() { + m.sendReportGroupAddrCount = make(map[tcpip.Address]int) + m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) +} + +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + + m.sendReportGroupAddrCount[groupAddress]++ + return !m.makeQueuePackets, nil +} + +func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + + m.sendLeaveGroupAddrCount[groupAddress]++ + return nil +} + +func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendReportGroupAddrCount := make(map[tcpip.Address]int) + for _, a := range sendReportGroupAddresses { + sendReportGroupAddrCount[a] = 1 + } + + sendLeaveGroupAddrCount := make(map[tcpip.Address]int) + for _, a := range sendLeaveGroupAddresses { + sendLeaveGroupAddrCount[a] = 1 + } + + diff := cmp.Diff( + &mockMulticastGroupProtocol{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, + m, + cmp.AllowUnexported(mockMulticastGroupProtocol{}), + // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t + cmp.FilterPath( + func(p cmp.Path) bool { + return p.Last().String() == ".mu" || p.Last().String() == ".t" || p.Last().String() == ".makeQueuePackets" + }, + cmp.Ignore(), + ), + ) + m.initLocked() + return diff +} + +func TestJoinGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendReports bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendReports: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendReports: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(0)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + // Joining a group should send a report immediately and another after + // a random interval between 0 and the maximum unsolicited report delay. + mgp.mu.Lock() + g.JoinGroupLocked(test.addr, false /* dontInitialize */) + mgp.mu.Unlock() + if test.shouldSendReports { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLeaveGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendMessages bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendMessages: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendMessages: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(1)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + mgp.mu.Lock() + g.JoinGroupLocked(test.addr, false /* dontInitialize */) + mgp.mu.Unlock() + if test.shouldSendMessages { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Leaving a group should send a leave report immediately and cancel any + // delayed reports. + { + mgp.mu.Lock() + res := g.LeaveGroupLocked(test.addr) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr) + } + } + if test.shouldSendMessages { + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleReport(t *testing.T) { + tests := []struct { + name string + reportAddr tcpip.Address + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + reportAddr: "", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + reportAddr: "\x00", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + reportAddr: addr1, + expectReportsFor: []tcpip.Address{addr2}, + }, + { + name: "Specified all-nodes", + reportAddr: addr3, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified other", + reportAddr: addr4, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(2)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report for a group we have a timer scheduled for should + // cancel our delayed report timer for the group. + mgp.mu.Lock() + g.HandleReportLocked(test.reportAddr) + mgp.mu.Unlock() + if len(test.expectReportsFor) != 0 { + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleQuery(t *testing.T) { + tests := []struct { + name string + queryAddr tcpip.Address + maxDelay time.Duration + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + queryAddr: "", + maxDelay: 0, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + queryAddr: "\x00", + maxDelay: 1, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + queryAddr: addr1, + maxDelay: 2, + expectReportsFor: []tcpip.Address{addr1}, + }, + { + name: "Specified all-nodes", + queryAddr: addr3, + maxDelay: 3, + expectReportsFor: nil, + }, + { + name: "Specified other", + queryAddr: addr4, + maxDelay: 4, + expectReportsFor: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should make us schedule a new delayed report if it + // is a query directed at us or a general query. + mgp.mu.Lock() + g.HandleQueryLocked(test.queryAddr, test.maxDelay) + mgp.mu.Unlock() + if len(test.expectReportsFor) != 0 { + clock.Advance(test.maxDelay) + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestJoinCount(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: time.Second, + }) + + // Set the join count to 2 for a group. + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + // Only the first join should trigger a report to be sent. + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Group should still be considered joined after leaving once. + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) + } + if !isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + // A leave report should only be sent once the join count reaches 0. + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Leaving once more should actually remove us from the group. + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) + } + if isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Group should no longer be joined so we should not have anything to + // leave. + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1) + } + if isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestMakeAllNonMemberAndInitialize(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should send the leave reports for each but still consider them locally + // joined. + mgp.mu.Lock() + g.MakeAllNonMemberLocked() + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + for _, group := range []tcpip.Address{addr1, addr2, addr3} { + mgp.mu.RLock() + res := g.IsLocallyJoinedRLocked(group) + mgp.mu.RUnlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group) + } + } + + // Should send the initial set of unsolcited reports. + mgp.mu.Lock() + g.InitializeGroupsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +// TestGroupStateNonMember tests that groups do not send packets when in the +// non-member state, but are still considered locally joined. +func TestGroupStateNonMember(t *testing.T) { + tests := []struct { + name string + enabled bool + dontInitialize bool + }{ + { + name: "Disabled", + enabled: false, + dontInitialize: false, + }, + { + name: "Keep non-member", + enabled: true, + dontInitialize: true, + }, + { + name: "disabled and Keep non-member", + enabled: false, + dontInitialize: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var g ip.GenericMulticastProtocolState + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: test.enabled, + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + // Joining groups should not send any reports. + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, test.dontInitialize) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, test.dontInitialize) + res := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should not send any reports. + mgp.mu.Lock() + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Nanosecond) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Leaving groups should not send any leave messages. + { + mgp.mu.Lock() + addr2LeaveRes := g.LeaveGroupLocked(addr2) + addr1IsJoined := g.IsLocallyJoinedRLocked(addr1) + addr2IsJoined := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !addr2LeaveRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2) + } + if !addr1IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + if addr2IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestQueuedPackets(t *testing.T) { + var g ip.GenericMulticastProtocolState + var mgp mockMulticastGroupProtocol + mgp.init() + clock := faketime.NewManualClock() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ + Enabled: true, + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + Protocol: &mgp, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + // Joining should trigger a SendReport, but mgp should report that we did not + // send the packet. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report timer should have been cancelled since we did not send + // the initial report earlier. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to successfully send the report. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send (we should be idle). + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query but mock being unable to send reports again. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to send reports again - we should have a packet queued to + // send. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query again, but mock being unable to send reports. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report should should transition us into the idle member state, + // even if we had a packet queued. We should no longer have any packets to + // send. + mgp.mu.Lock() + g.HandleReportLocked(addr1) + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // When we fail to send the initial set of reports, incoming reports should + // not affect a newly joined group's reports from being sent. + mgp.mu.Lock() + mgp.makeQueuePackets = true + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.mu.Lock() + g.HandleReportLocked(addr2) + // Attempting to send queued reports while still unable to send reports should + // not change the host state. + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Mock being able to successfully send the report. + mgp.mu.Lock() + mgp.makeQueuePackets = false + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.mu.Lock() + g.SendQueuedReportsLocked() + mgp.mu.Unlock() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f20b94d97..3005973d7 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -34,16 +35,16 @@ import ( ) const ( - localIPv4Addr = "\x0a\x00\x00\x01" - remoteIPv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") + remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") + ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") + ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") + ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") + localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") + ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") + ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") nicID = 1 ) @@ -110,8 +111,9 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { - t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { + netHdr := pkt.Network() + t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress()) t.dataCalls++ return stack.TransportPacketHandled } @@ -191,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu panic("not implemented") } -func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - return tcpip.ErrNotSupported -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*testObject) ARPHardwareType() header.ARPHardwareType { panic("not implemented") @@ -205,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net panic("not implemented") } -func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { +func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -221,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } -func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { +func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -298,12 +296,20 @@ func (t *testInterface) Enabled() bool { return !t.mu.disabled } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) setEnabled(v bool) { t.mu.Lock() defer t.mu.Unlock() t.mu.disabled = !v } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -315,7 +321,6 @@ func TestSourceAddressValidation(t *testing.T) { pkt.SetChecksum(^header.Checksum(pkt, 0)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(totalLen), Protocol: uint8(icmp.ProtocolNumber4), TTL: ipv4.DefaultTTL, @@ -339,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -545,7 +550,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -554,59 +559,135 @@ func TestIPv4Send(t *testing.T) { } } -func TestIPv4Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, +func TestReceive(t *testing.T) { + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + v4 bool + epAddr tcpip.AddressWithPrefix + handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + v4: true, + epAddr: localIPv4Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const totalLen = header.IPv4MinimumSize + 30 /* payload length */ + + view := buffer.NewView(totalLen) + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + TTL: ipv4.DefaultTTL, + Protocol: 10, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if ok := parse.IPv4(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + v4: false, + epAddr: localIPv6Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const payloadLen = 30 + view := buffer.NewView(header.IPv6MinimumSize + payloadLen) + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadLen, + TransportProtocol: 10, + HopLimit: ipv6.DefaultTTL, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, + }) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } + // Make payload be non-zero. + for i := header.IPv6MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } - totalLen := header.IPv4MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) + // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, _, _, ok := parse.IPv6(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, + }, } - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + nic := testInterface{ + testObject: testObject{ + t: t, + v4: test.v4, + }, + } + ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject) + defer ep.Close() - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - ep.HandlePacket(&r, pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) + } + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + } else { + ep.DecRef() + } + + stat := s.Stats().IP.PacketsReceived + if got := stat.Value(); got != 0 { + t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) + } + test.handlePacket(t, ep, &nic) + if nic.testObject.dataCalls != 1 { + t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + } + if got := stat.Value(); got != 1 { + t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) + } + }) } } @@ -630,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -656,7 +733,6 @@ func TestIPv4ReceiveControl(t *testing.T) { // Create the outer IPv4 header. ip := header.IPv4(view) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(len(view) - c.trunc), TTL: 20, Protocol: uint8(header.ICMPv4ProtocolNumber), @@ -675,7 +751,6 @@ func TestIPv4ReceiveControl(t *testing.T) { // Create the inner IPv4 header. ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: 100, TTL: 20, Protocol: 10, @@ -690,6 +765,10 @@ func TestIPv4ReceiveControl(t *testing.T) { view[i] = uint8(i) } + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) + // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. nic.testObject.protocol = 10 @@ -699,7 +778,19 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv4Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -708,7 +799,9 @@ func TestIPv4ReceiveControl(t *testing.T) { } func TestIPv4FragmentationReceive(t *testing.T) { - s := buildDummyStack(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ testObject: testObject{ @@ -728,7 +821,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { frag1 := buffer.NewView(totalLen) ip1 := header.IPv4(frag1) ip1.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(totalLen), TTL: 20, Protocol: 10, @@ -747,7 +839,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { frag2 := buffer.NewView(totalLen) ip2 := header.IPv4(frag2) ip2.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(totalLen), TTL: 20, Protocol: 10, @@ -768,11 +859,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { nic.testObject.dstAddr = localIPv4Addr nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - // Send first segment. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag1.ToVectorisedView(), @@ -780,7 +866,19 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv4Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } @@ -792,7 +890,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -835,7 +933,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -844,60 +942,6 @@ func TestIPv6Send(t *testing.T) { } } -func TestIPv6Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - totalLen := header.IPv6MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(totalLen - header.IPv6MinimumSize), - NextHeader: 10, - HopLimit: 20, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] - - r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - ep.HandlePacket(&r, pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } -} - func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } @@ -924,13 +968,6 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } - r, err := buildIPv6Route( - localIPv6Addr, - "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - ) - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -956,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) { // Create the outer IPv6 header. ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIPv6Addr, + PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 20, + SrcAddr: outerSrcAddr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -970,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp.SetIdent(0xdead) icmp.SetSequence(0xbeef) - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - NextHeader: 10, - HopLimit: 20, - SrcAddr: localIPv6Addr, - DstAddr: remoteIPv6Addr, - }) - + var extHdrs header.IPv6ExtHdrSerializer // Build the fragmentation header if needed. if c.fragmentOffset != nil { - ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) - frag.Encode(&header.IPv6FragmentFields{ - NextHeader: 10, + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ FragmentOffset: *c.fragmentOffset, M: true, Identification: 0x12345678, }) } + // Create the inner IPv6 header. + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) + ip.Encode(&header.IPv6Fields{ + PayloadLength: 100, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, + ExtensionHeaders: extHdrs, + }) + // Make payload be non-zero. for i := dataOffset; i < len(view); i++ { view[i] = uint8(i) @@ -1009,7 +1045,18 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv6Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -1035,15 +1082,25 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicID = 1 transportProto = 5 - dataLen = 4 - optionsLen = 4 + dataLen = 4 ) dataBuf := [dataLen]byte{1, 2, 3, 4} data := dataBuf[:] - ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1} - ipv4Options := ipv4OptionsBuf[:] + ipv4Options := header.IPv4OptionsSerializer{ + &header.IPv4SerializableListEndOption{}, + &header.IPv4SerializableNOPOption{}, + &header.IPv4SerializableListEndOption{}, + &header.IPv4SerializableNOPOption{}, + } + + expectOptions := header.IPv4Options{ + byte(header.IPv4OptionListEndType), + byte(header.IPv4OptionNOPType), + byte(header.IPv4OptionListEndType), + byte(header.IPv4OptionNOPType), + } ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] @@ -1063,7 +1120,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum tcpip.NetworkProtocolNumber nicAddr tcpip.Address remoteAddr tcpip.Address - pktGen func(*testing.T, tcpip.Address) buffer.View + pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) expectedErr *tcpip.Error }{ @@ -1073,7 +1130,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1081,13 +1138,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1115,7 +1171,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1123,13 +1179,13 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize - 1, Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + ip.SetHeaderLength(header.IPv4MinimumSize - 1) + return hdr.View().ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1139,16 +1195,15 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1158,16 +1213,15 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1195,8 +1249,8 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { - ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { + ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1204,16 +1258,54 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv4(hdr.Prepend(ipHdrLen)) ip.Encode(&header.IPv4Fields{ - IHL: uint8(ipHdrLen), Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, DstAddr: header.IPv4Any, + Options: ipv4Options, }) - if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { - t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) + return hdr.View().ToVectorisedView() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) } - return hdr.View() + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(expectOptions), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with options and data across views", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) + ip.Encode(&header.IPv4Fields{ + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + Options: ipv4Options, + }) + vv := buffer.View(ip).ToVectorisedView() + vv.AppendView(data) + return vv }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1222,7 +1314,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { netHdr := pkt.NetworkHeader() - hdrLen := header.IPv4MinimumSize + len(ipv4Options) + hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) if len(netHdr.View()) != hdrLen { t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) } @@ -1232,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { checker.DstAddr(remoteIPv4Addr), checker.IPv4HeaderLength(hdrLen), checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(ipv4Options), + checker.IPv4Options(expectOptions), checker.IPPayload(data), ) }, @@ -1243,7 +1335,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1251,12 +1343,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1283,7 +1375,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1294,12 +1386,14 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + // NB: we're lying about transport protocol here to verify the raw + // fragment header bytes. + TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1326,15 +1420,15 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1361,15 +1455,15 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1413,7 +1507,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { defer r.Release() if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + Data: test.pktGen(t, subTest.srcAddr), })); err != test.expectedErr { t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7fc12e229..32f53f217 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -6,6 +6,7 @@ go_library( name = "ipv4", srcs = [ "icmp.go", + "igmp.go", "ipv4.go", ], visibility = ["//visibility:public"], @@ -17,6 +18,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -24,11 +26,15 @@ go_library( go_test( name = "ipv4_test", size = "small", - srcs = ["ipv4_test.go"], + srcs = [ + "igmp_test.go", + "ipv4_test.go", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3407755ed..8e392f86c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -23,10 +24,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// handleControl handles the case when an ICMP packet contains the headers of -// the original packet that caused the ICMP one to be sent. This information is -// used to find out which transport endpoint must be notified about the ICMP -// packet. +// handleControl handles the case when an ICMP error packet contains the headers +// of the original packet that caused the ICMP one to be sent. This information +// is used to find out which transport endpoint must be notified about the ICMP +// packet. We only expect the payload, not the enclosing ICMP packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { @@ -41,8 +42,8 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. - src := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { + srcAddr := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 { return } @@ -57,12 +58,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { - stats := r.Stats() - received := stats.ICMP.V4PacketsReceived +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + received := stats.ICMP.V4.PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. @@ -73,20 +74,65 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } h := header.ICMPv4(v) + // Only do in-stack processing if the checksum is correct. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { + received.Invalid.Increment() + // It's possible that a raw socket expects to receive this regardless + // of checksum errors. If it's an echo request we know it's safe because + // we are the only handler, however other types do not cope well with + // packets with checksum errors. + switch h.Type() { + case header.ICMPv4Echo: + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) + } + return + } + + iph := header.IPv4(pkt.NetworkHeader().View()) + var newOptions header.IPv4Options + if opts := iph.Options(); len(opts) != 0 { + // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip + // type ICMP packets): + // If a Record Route and/or Time Stamp option is received in an + // ICMP Echo Request, this option (these options) SHOULD be + // updated to include the current host and included in the IP + // header of the Echo Reply message, without "truncation". + // Thus, the recorded route will be for the entire round trip. + // + // So we need to let the option processor know how it should handle them. + var op optionsUsage + if h.Type() == header.ICMPv4Echo { + op = &optionUsageEcho{} + } else { + op = &optionUsageReceive{} + } + aux, tmp, err := e.processIPOptions(pkt, opts, op) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() + } + return + } + newOptions = tmp + } + // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - // Only send a reply if the checksum is valid. - headerChecksum := h.Checksum() - h.SetChecksum(0) - calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - h.SetChecksum(headerChecksum) - if calculatedChecksum != headerChecksum { - // It's possible that a raw socket still expects to receive this. - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - received.Invalid.Increment() + sent := stats.ICMP.V4.PacketsSent + if !e.protocol.stack.AllowICMPMessage() { + sent.RateLimited.Increment() return } @@ -98,19 +144,27 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. replyData := pkt.Data.ToOwnedView() - replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast + + // It's possible that a raw socket expects to receive this. + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) + pkt = nil - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + // Take the base of the incoming request IP header but replace the options. + replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions)) + replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...)) + replyIPHdr.SetHeaderLength(replyHeaderLength) // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). - localAddr := r.LocalAddress - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { + localAddr := ipHdr.DestinationAddress() + if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -139,7 +193,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // The fields we need to alter. // // We need to produce the entire packet in the data segment in order to - // use WriteHeaderIncludedPacket(). + // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the + // total length and the header checksum so we don't need to set those here. replyIPHdr.SetSourceAddress(r.LocalAddress) replyIPHdr.SetDestinationAddress(r.RemoteAddress) replyIPHdr.SetTTL(r.DefaultTTL()) @@ -157,8 +212,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { }) replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - // The checksum will be calculated so we don't need to do it here. - sent := stats.ICMP.V4PacketsSent if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return @@ -168,7 +221,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() @@ -182,8 +235,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: - mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) } case header.ICMPv4SrcQuench: @@ -234,12 +290,38 @@ type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in +// transit to its final destination, as per RFC 792 page 6, Time Exceeded +// Message. +type icmpReasonTTLExceeded struct{} + +func (*icmpReasonTTLExceeded) isICMPReason() {} + +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + +// icmpReasonParamProblem is an error to use to request a Parameter Problem +// message to be sent. +type icmpReasonParamProblem struct { + pointer byte +} + +func (*icmpReasonParamProblem) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv4(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -263,35 +345,50 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in // response to a non-initial fragment, but it currently can not happen. - - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any { return nil } + // If we hit a TTL Exceeded error, then we know we are operating as a router. + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + // + // ... + // + // Code 0 may be received from a gateway. ... + // + // Note, Code 0 is the TTL exceeded error. + // + // If we are operating as a router/gateway, don't use the packet's destination + // address as the response's source address as we should not not own the + // destination address of a packet we are forwarding. + localAddr := origIPHdrDst + if _, ok := reason.(*icmpReasonTTLExceeded); ok { + localAddr = "" + } // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil - sent := p.stack.Stats().ICMP.V4PacketsSent + sent := p.stack.Stats().ICMP.V4.PacketsSent if !p.stack.AllowICMPMessage() { sent.RateLimited.Increment() return nil } - networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. - if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { // TODO(gvisor.dev/issue/3810): // Unfortunately the current stack pretty much always has ICMPv4 headers // in the Data section of the packet but there is no guarantee that is the @@ -348,7 +445,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -360,7 +457,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // view with the entire incoming IP packet reassembled and truncated as // required. This is now the payload of the new ICMP packet and no longer // considered a packet in its own right. - newHeader := append(buffer.View(nil), networkHeader...) + newHeader := append(buffer.View(nil), origIPHdr...) newHeader = append(newHeader, transportHeader...) payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) @@ -374,17 +471,33 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - switch reason.(type) { + var counter *tcpip.StatCounter + switch reason := reason.(type) { case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) + counter = sent.DstUnreachable case *icmpReasonProtoUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + counter = sent.DstUnreachable + case *icmpReasonTTLExceeded: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4TTLExceeded) + counter = sent.TimeExceeded + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) + counter = sent.TimeExceeded + case *icmpReasonParamProblem: + icmpHdr.SetType(header.ICMPv4ParamProblem) + icmpHdr.SetCode(header.ICMPv4UnusedCode) + icmpHdr.SetPointer(reason.pointer) + counter = sent.ParamProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) - counter := sent.DstUnreachable if err := route.WritePacket( nil, /* gso */ @@ -401,3 +514,18 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac counter.Increment() return nil } + +// OnReassemblyTimeout implements fragmentation.TimeoutHandler. +func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + // OnReassemblyTimeout sends a Time Exceeded Message, as per RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards the + // datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + if pkt != nil { + p.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } +} diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go new file mode 100644 index 000000000..fb7a9e68e --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -0,0 +1,344 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "fmt" + "sync/atomic" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // igmpV1PresentDefault is the initial state for igmpV1Present in the + // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is + // the initial state." + igmpV1PresentDefault = 0 + + // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18 + // See note on igmpState.igmpV1Present for more detail. + v1RouterPresentTimeout = 400 * time.Second + + // v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router + // will send General Queries with the Max Response Time set to 0. This MUST + // be interpreted as a value of 100 (10 seconds)." + // + // Note that the Max Response Time field is a value in units of deciseconds. + v1MaxRespTime = 10 * time.Second + + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited IGMP reports. + // + // Obtained from RFC 2236 Section 8.10, Page 19. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// IGMPOptions holds options for IGMP. +type IGMPOptions struct { + // Enabled indicates whether IGMP will be performed. + // + // When enabled, IGMP may transmit IGMP report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // IGMP packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*igmpState)(nil) + +// igmpState is the per-interface IGMP state. +// +// igmpState.init() MUST be called after creating an IGMP state. +type igmpState struct { + // The IPv4 endpoint this igmpState is for. + ep *endpoint + + enabled bool + + genericMulticastProtocol ip.GenericMulticastProtocolState + + // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from + // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 + // Membership Reports in response to its Queries, and will not pay + // attention to Version 2 Membership Reports. Therefore, a state variable + // MUST be kept for each interface, describing whether the multicast + // Querier on that interface is running IGMPv1 or IGMPv2. This variable + // MUST be based upon whether or not an IGMPv1 query was heard in the last + // [Version 1 Router Present Timeout] seconds". + // + // Must be accessed with atomic operations. Holds a value of 1 when true, 0 + // when false. + igmpV1Present uint32 + + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + igmpType := header.IGMPv2MembershipReport + if igmp.v1Present() { + igmpType = header.IGMPv1MembershipReport + } + return igmp.writePacket(groupAddress, groupAddress, igmpType) +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + // As per RFC 2236 Section 6, Page 8: "If the interface state says the + // Querier is running IGMPv1, this action SHOULD be skipped. If the flag + // saying we were the last host to report is cleared, this action MAY be + // skipped." + if igmp.v1Present() { + return nil + } + _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + return err +} + +// init sets up an igmpState struct, and is required to be called before using +// a new igmpState. +// +// Must only be called once for the lifetime of igmp. +func (igmp *igmpState) init(ep *endpoint) { + igmp.ep = ep + // No need to perform IGMP on loopback interfaces since they don't have + // neighbouring nodes. + igmp.enabled = ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() + igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Enabled: igmp.enabled, + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: igmp, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv4AllSystems, + }) + igmp.igmpV1Present = igmpV1PresentDefault + igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { + igmp.setV1Present(false) + }) +} + +// handleIGMP handles an IGMP packet. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { + stats := igmp.ep.protocol.stack.Stats() + received := stats.IGMP.PacketsReceived + headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + h := header.IGMP(headerView) + + // Temporarily reset the checksum field to 0 in order to calculate the proper + // checksum. + wantChecksum := h.Checksum() + h.SetChecksum(0) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(wantChecksum) + + if gotChecksum != wantChecksum { + received.ChecksumErrors.Increment() + return + } + + switch h.Type() { + case header.IGMPMembershipQuery: + received.MembershipQuery.Increment() + if len(headerView) < header.IGMPQueryMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) + case header.IGMPv1MembershipReport: + received.V1MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPv2MembershipReport: + received.V2MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPLeaveGroup: + received.LeaveGroup.Increment() + // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or + // Report, are ignored in all states" + + default: + // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should + // be silently ignored. New message types may be used by newer versions of + // IGMP, by multicast routing protocols, or other uses." + received.Unrecognized.Increment() + } +} + +func (igmp *igmpState) v1Present() bool { + return atomic.LoadUint32(&igmp.igmpV1Present) == 1 +} + +func (igmp *igmpState) setV1Present(v bool) { + if v { + atomic.StoreUint32(&igmp.igmpV1Present, 1) + } else { + atomic.StoreUint32(&igmp.igmpV1Present, 0) + } +} + +// handleMembershipQuery handles a membership query. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { + // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero + // then change the state to note that an IGMPv1 router is present and + // schedule the query received Job. + if igmp.enabled && maxRespTime == 0 { + igmp.igmpV1Job.Cancel() + igmp.igmpV1Job.Schedule(v1RouterPresentTimeout) + igmp.setV1Present(true) + maxRespTime = v1MaxRespTime + } + + igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime) +} + +// handleMembershipReport handles a membership report. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { + igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) +} + +// writePacket assembles and sends an IGMP packet. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { + igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) + igmpData.SetType(igmpType) + igmpData.SetGroupAddress(groupAddress) + igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()), + Data: buffer.View(igmpData).ToVectorisedView(), + }) + + addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */) + if addressEndpoint == nil { + return false, nil + } + localAddr := addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() + addressEndpoint = nil + igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.IGMPProtocolNumber, + TTL: header.IGMPTTL, + TOS: stack.DefaultTOS, + }, header.IPv4OptionsSerializer{ + &header.IPv4SerializableRouterAlertOption{}, + }) + + sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sentStats.Dropped.Increment() + return false, err + } + switch igmpType { + case header.IGMPv1MembershipReport: + sentStats.V1MembershipReport.Increment() + case header.IGMPv2MembershipReport: + sentStats.V2MembershipReport.Increment() + case header.IGMPLeaveGroup: + sentStats.LeaveGroup.Increment() + default: + panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) + } + return true, nil +} + +// joinGroup handles adding a new group to the membership map, setting up the +// IGMP state for the group, and sending and scheduling the required +// messages. +// +// If the group already exists in the membership map, returns +// tcpip.ErrDuplicateAddress. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { + igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) +} + +// isInGroup returns true if the specified group has been joined locally. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { + return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Leave Group message +// if required. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { + return nil + } + + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of IGMP, but remains +// joined locally. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) softLeaveAll() { + igmp.genericMulticastProtocol.MakeAllNonMemberLocked() +} + +// initializeAll attemps to initialize the IGMP state for each group that has +// been joined locally. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) initializeAll() { + igmp.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) sendQueuedReports() { + igmp.genericMulticastProtocol.SendQueuedReportsLocked() +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go new file mode 100644 index 000000000..1ee573ac8 --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -0,0 +1,215 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + addr = tcpip.Address("\x0a\x00\x00\x01") + multicastAddr = tcpip.Address("\xe0\x00\x00\x03") + nicID = 1 +) + +// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. Raises a t.Error if +// any field does not match. +func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.SrcAddr(addr), + checker.DstAddr(remoteAddress), + // TTL for an IGMP message must be 1 as per RFC 2236 section 2. + checker.TTL(1), + checker.IPv4RouterAlert(), + checker.IGMP( + checker.IGMPType(igmpType), + checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + // Create an endpoint of queue size 1, since no more than 1 packets are ever + // queued in the tests in this file. + e := channel.New(1, 1280, linkAddr) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMP: ipv4.IGMPOptions{ + Enabled: igmpEnabled, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + return e, s, clock +} + +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: 1, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(igmpType) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is +// present on the network. The IGMP stack will then send IGMPv1 Membership +// reports for backwards compatibility. +func TestIgmpV1Present(t *testing.T) { + e, s, clock := createStack(t, true) + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // This NIC will send an IGMPv2 report immediately, before this test can get + // the IGMPv1 General Membership Query in. + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Inject an IGMPv1 General Membership Query which is identical to a standard + // membership query except the Max Response Time is set to 0, which will tell + // the stack that this is a router using IGMPv1. Send it to the all systems + // group which is the only group this host belongs to. + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems) + if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { + t.Fatalf("got Membership Queries received = %d, want = 1", got) + } + + // Before advancing the clock, verify that this host has not sent a + // V1MembershipReport yet. + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) + } + + // Verify the solicited Membership Report is sent. Now that this NIC has seen + // an IGMPv1 query, it should send an IGMPv1 Membership Report. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) +} + +func TestSendQueuedIGMPReports(t *testing.T) { + e, s, clock := createStack(t, true) + + // Joining a group without an assigned address should queue IGMP packets; none + // should be sent without an assigned address. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) + } + reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport + if got := reportStat.Value(); got != 0 { + t.Errorf("got reportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } + + // The initial set of IGMP reports that were queued should be sent once an + // address is assigned. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + if got := reportStat.Value(); got != 1 { + t.Errorf("got reportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + if got := reportStat.Value(); got != 2 { + t.Errorf("got reportStat.Value() = %d, want = 2", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + + // Should have no more packets to send after the initial set of unsolicited + // reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e7c58ae0a..e9ff70d04 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -16,7 +16,9 @@ package ipv4 import ( + "errors" "fmt" + "math" "sync/atomic" "time" @@ -31,6 +33,8 @@ import ( ) const ( + // ReassembleTimeout is the time a packet stays in the reassembly + // system before being evicted. // As per RFC 791 section 3.2: // The current recommendation for the initial timer setting is 15 seconds. // This may be changed as experience with this protocol accumulates. @@ -38,7 +42,7 @@ const ( // Considering that it is an old recommendation, we use the same reassembly // timeout that linux defines, which is 30 seconds: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138 - reassembleTimeout = 30 * time.Second + ReassembleTimeout = 30 * time.Second // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -79,6 +83,7 @@ type endpoint struct { sync.RWMutex addressableEndpointState stack.AddressableEndpointState + igmp igmpState } } @@ -89,7 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) + e.mu.igmp.init(e) + e.mu.Unlock() return e } @@ -117,11 +125,22 @@ func (e *endpoint) Enable() *tcpip.Error { // We have no need for the address endpoint. ep.DecRef() + // Groups may have been joined while the endpoint was disabled, or the + // endpoint may have left groups from the perspective of IGMP when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + e.mu.igmp.initializeAll() + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the // all-systems multicast group. - _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) - return err + if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err)) + } + + return nil } // Enabled implements stack.NetworkEndpoint. @@ -153,19 +172,27 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.isEnabled() { return } // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } + // Leave groups from the perspective of IGMP so that routers know that + // we are no longer interested in the group. + e.mu.igmp.softLeaveAll() + // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // DefaultTTL is the default time-to-live value for this endpoint. @@ -176,7 +203,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and @@ -190,39 +221,48 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { - ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) { + hdrLen := header.IPv4MinimumSize + var optLen int + if options != nil { + optLen = int(options.Length()) + } + hdrLen += optLen + if hdrLen > header.IPv4MaximumHeaderSize { + // Since we have no way to report an error we must either panic or create + // a packet which is different to what was requested. Choose panic as this + // would be a programming error that should be caught in testing. + panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize)) + } + ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := uint16(pkt.Size()) // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. - id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: length, ID: uint16(id), TTL: params.TTL, TOS: params.TOS, Protocol: uint8(params.Protocol), - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: srcAddr, + DstAddr: dstAddr, + Options: options, }) ip.SetChecksum(^ip.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) -} - // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) +// original packet. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + // Round the MTU down to align to 8 bytes. + fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader)) var n int for { @@ -239,18 +279,14 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt) -} + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -265,23 +301,43 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) + } return nil } } + return e.writePacket(r, gso, pkt, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, pkt) - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -292,6 +348,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err @@ -310,18 +367,24 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r, pkt, params) - if e.packetMustBeFragmented(pkt, gso) { + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt return nil }); err != nil { - panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err)) + panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err)) } // Remove the packet that was just fragmented and process the rest. pkts.Remove(originalPkt) @@ -331,8 +394,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -355,10 +417,13 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) + } n++ continue } @@ -385,6 +450,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return tcpip.ErrMalformedHeader } + + hdrLen := header.IPv4(h).HeaderLength() + if hdrLen < header.IPv4MinimumSize { + return tcpip.ErrMalformedHeader + } + + h, ok = pkt.Data.PullUp(int(hdrLen)) + if !ok { + return tcpip.ErrMalformedHeader + } ip := header.IPv4(h) // Always set the total length. @@ -406,7 +481,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } @@ -424,19 +499,91 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt) + return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv4(pkt.NetworkHeader().View()) + ttl := h.TTL() + if ttl == 0 { + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + } + + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader())) + + // As per RFC 791 page 30, Time to Live, + // + // This field must be decreased at each point that the internet header + // is processed to reflect the time spent processing the datagram. + // Even if no local information is available on the time actually + // spent, the field must be decremented by 1. + newHdr.SetTTL(ttl - 1) + + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(newHdr).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -462,25 +609,52 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + // As per RFC 1122 section 3.2.1.3: // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() + return + } + // Make sure the source address is not a subnet-local broadcast address. + if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + if subnet.IsBroadcast(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() + return + } + } + + // The destination address should be an address we own or a group we joined + // for us to receive the packet. Otherwise, attempt to forward the packet. + if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { + subnet := addressEndpoint.AddressWithPrefix().Subnet() + addressEndpoint.DecRef() + pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + } else if !e.IsInGroup(dstAddr) { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) return } // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -488,8 +662,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. @@ -502,14 +676,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } - var ready bool - var err error + proto := h.Protocol() - pkt.Data, _, ready, err = e.protocol.fragmentation.Process( + data, _, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -522,30 +695,63 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { start+uint16(pkt.Data.Size())-1, h.More(), proto, - pkt.Data, + pkt, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if !ready { return } + pkt.Data = data + + // The reassembler doesn't take care of fixing up the header, so we need + // to do it here. + h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) + h.SetFlagsFragmentOffset(0, 0) } + stats.IP.PacketsDelivered.Increment() - r.Stats().IP.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport // headers, the setting of the transport number here should be // unnecessary and removed. pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt) + e.handleICMP(pkt) + return + } + if p == header.IGMPProtocolNumber { + e.mu.Lock() + e.mu.igmp.handleIGMP(pkt) + e.mu.Unlock() return } + if opts := h.Options(); len(opts) != 0 { + // TODO(gvisor.dev/issue/4586): + // When we add forwarding support we should use the verified options + // rather than just throwing them away. + aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{}) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() + } + return + } + } - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination @@ -553,13 +759,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC: 1122 Section 3.2.2.1 // A host SHOULD generate Destination Unreachable messages with code: // 2 (Protocol Unreachable), when the designated transport protocol // is not supported - _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -578,7 +784,12 @@ func (e *endpoint) Close() { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err == nil { + e.mu.igmp.sendQueuedReports() + } + return ep, err } // RemovePermanentAddress implements stack.AddressableEndpoint. @@ -601,34 +812,26 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo defer e.mu.Unlock() loopback := e.nic.IsLoopback() - addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { - subnet := addressEndpoint.AddressWithPrefix().Subnet() + return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool { + subnet := addressEndpoint.Subnet() // IPv4 has a notion of a subnet broadcast address and considers the // loopback interface bound to an address's whole subnet (on linux). return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) - }) - if addressEndpoint != nil { - return addressEndpoint - } - - if !allowTemp { - return nil - } - - addr := localAddr.WithPrefix() - addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB) - if err != nil { - // AddAddress only returns an error if the address is already assigned, - // but we just checked above if the address exists so we expect no error. - panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err)) - } - return addressEndpoint + }, allowTemp, tempPEB) } // AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { e.mu.RLock() defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements +// +// Precondition: igmp.ep.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) } @@ -647,32 +850,48 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV4MulticastAddress(addr) { - return false, tcpip.ErrBadAddress + return tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + e.mu.igmp.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mu.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mu.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) +var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack @@ -693,6 +912,8 @@ type protocol struct { hashIV uint32 fragmentation *fragmentation.Fragmentation + + options Options } // Number returns the ipv4 protocol number. @@ -778,26 +999,32 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload mtu. +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv4MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState } - return mtu - header.IPv4MinimumSize -} -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize + // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in + // length: + // The maximal internet header is 60 octets, and a typical internet header + // is 20 octets, allowing a margin for headers of higher level protocols. + if networkHeaderSize > header.IPv4MaximumHeaderSize { + return 0, tcpip.ErrMalformedHeader } - mtu -= uint32(pkt.NetworkHeader().View().Size()) - // Round the MTU down to align to 8 bytes. - mtu &^= 7 - return mtu + + networkMTU := linkMTU + if networkMTU > MaxTotalSize { + networkMTU = MaxTotalSize + } + + return networkMTU - uint32(networkHeaderSize), nil +} + +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // addressToUint32 translates an IPv4 address into its little endian uint32 @@ -811,17 +1038,23 @@ func addressToUint32(addr tcpip.Address) uint32 { return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 } -// hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number and a 32-bit number to -// generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - a := addressToUint32(r.LocalAddress) - b := addressToUint32(r.RemoteAddress) +// hashRoute calculates a hash value for the given source/destination pair using +// the addresses, transport protocol number and a 32-bit number to generate the +// hash. +func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { + a := addressToUint32(srcAddr) + b := addressToUint32(dstAddr) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -// NewProtocol returns an IPv4 network protocol. -func NewProtocol(s *stack.Stack) stack.NetworkProtocol { +// Options holds options to configure a new protocol. +type Options struct { + // IGMP holds options for IGMP. + IGMP IGMPOptions +} + +// NewProtocolWithOptions returns an IPv4 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -831,21 +1064,31 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ - stack: s, - ids: ids, - hashIV: hashIV, - defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + options: opts, + } + p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + return p } } +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) +} + func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { fragPkt, offset, copied, more := pf.BuildNextFragment() fragPkt.NetworkProtocolNumber = ProtocolNumber originalIPHeaderLength := len(originalIPHeader) nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) @@ -862,3 +1105,338 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head return fragPkt, more } + +// optionAction describes possible actions that may be taken on an option +// while processing it. +type optionAction uint8 + +const ( + // optionRemove says that the option should not be in the output option set. + optionRemove optionAction = iota + + // optionProcess says that the option should be fully processed. + optionProcess + + // optionVerify says the option should be checked and passed unchanged. + optionVerify + + // optionPass says to pass the output set without checking. + optionPass +) + +// optionActions list what to do for each option in a given scenario. +type optionActions struct { + // timestamp controls what to do with a Timestamp option. + timestamp optionAction + + // recordroute controls what to do with a Record Route option. + recordRoute optionAction + + // unknown controls what to do with an unknown option. + unknown optionAction +} + +// optionsUsage specifies the ways options may be operated upon for a given +// scenario during packet processing. +type optionsUsage interface { + actions() optionActions +} + +// optionUsageReceive implements optionsUsage for received packets. +type optionUsageReceive struct{} + +// actions implements optionsUsage. +func (*optionUsageReceive) actions() optionActions { + return optionActions{ + timestamp: optionVerify, + recordRoute: optionVerify, + unknown: optionPass, + } +} + +// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it +// is enabled (Process, Process, Pass) and for fragmenting (Process, Process, +// Pass for frag1, but Remove,Remove,Remove for all other frags). + +// optionUsageEcho implements optionsUsage for echo packet processing. +type optionUsageEcho struct{} + +// actions implements optionsUsage. +func (*optionUsageEcho) actions() optionActions { + return optionActions{ + timestamp: optionProcess, + recordRoute: optionProcess, + unknown: optionRemove, + } +} + +var ( + errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") + errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") + errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") + errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") +) + +// handleTimestamp does any required processing on a Timestamp option +// in place. +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { + flags := tsOpt.Flags() + var entrySize uint8 + switch flags { + case header.IPv4OptionTimestampOnlyFlag: + entrySize = header.IPv4OptionTimestampSize + case + header.IPv4OptionTimestampWithIPFlag, + header.IPv4OptionTimestampWithPredefinedIPFlag: + entrySize = header.IPv4OptionTimestampWithAddrSize + default: + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + } + + pointer := tsOpt.Pointer() + // RFC 791 page 22 states: "The smallest legal value is 5." + // Since the pointer is 1 based, and the header is 4 bytes long the + // pointer must point beyond the header therefore 4 or less is bad. + if pointer <= header.IPv4OptionTimestampHdrLength { + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } + // To simplify processing below, base further work on the array of timestamps + // beyond the header, rather than on the whole option. Also to aid + // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. + nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1) + optLen := tsOpt.Size() + dataLength := optLen - header.IPv4OptionTimestampHdrLength + + // In the section below, we verify the pointer, length and overflow counter + // fields of the option. The distinction is in which byte you return as being + // in error in the ICMP packet. Offsets 1 (length), 2 pointer) + // or 3 (overflowed counter). + // + // The following RFC sections cover this section: + // + // RFC 791 (page 22): + // If there is some room but not enough room for a full timestamp + // to be inserted, or the overflow count itself overflows, the + // original datagram is considered to be in error and is discarded. + // In either case an ICMP parameter problem message may be sent to + // the source host [3]. + // + // You can get this situation in two ways. Firstly if the data area is not + // a multiple of the entry size or secondly, if the pointer is not at a + // multiple of the entry size. The wording of the RFC suggests that + // this is not an error until you actually run out of space. + if pointer > optLen { + // RFC 791 (page 22) says we should switch to using the overflow count. + // If the timestamp data area is already full (the pointer exceeds + // the length) the datagram is forwarded without inserting the + // timestamp, but the overflow count is incremented by one. + if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { + // By definition we have nothing to do. + return 0, nil + } + + if tsOpt.IncOverflow() != 0 { + return 0, nil + } + // The overflow count is also full. + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + } + if nextSlot+entrySize > dataLength { + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. + if false { + // We must select Pointer for Linux compatibility, even if + // only the length is bad. + // The Linux code is at (in October 2020) + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + // which doesn't distinguish between which of optptr[2] or optlen + // is wrong, but just arbitrarily decides on optptr+2. + if dataLength%entrySize != 0 { + // The Data section size should be a multiple of the expected + // timestamp entry size. + return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + } + // If the size is OK, the pointer must be corrupted. + } + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } + + if usage.actions().timestamp == optionProcess { + tsOpt.UpdateTimestamp(localAddress, clock) + } + return 0, nil +} + +var ( + errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") + errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") +) + +// handleRecordRoute checks and processes a Record route option. It is much +// like the timestamp type 1 option, but without timestamps. The passed in +// address is stored in the option in the correct spot if possible. +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { + optlen := rrOpt.Size() + + if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + + pointer := rrOpt.Pointer() + // RFC 791 page 20 states: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // Since the pointer is 1 based, and the header is 3 bytes long the + // pointer must point beyond the header therefore 3 or less is bad. + if pointer <= header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } + + // RFC 791 page 21 says + // If the route data area is already full (the pointer exceeds the + // length) the datagram is forwarded without inserting the address + // into the recorded route. If there is some room but not enough + // room for a full address to be inserted, the original datagram is + // considered to be in error and is discarded. In either case an + // ICMP parameter problem message may be sent to the source + // host. + // The use of the words "In either case" suggests that a 'full' RR option + // could generate an ICMP at every hop after it fills up. We chose to not + // do this (as do most implementations). It is probable that the inclusion + // of these words is a copy/paste error from the timestamp option where + // there are two failure reasons given. + if pointer > optlen { + return 0, nil + } + + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. We must select Pointer for Linux + // compatibility, even if only the length is bad. NB. pointer is 1 based. + if pointer+header.IPv4AddressSize > optlen+1 { + if false { + // This is what we would do if we were not being Linux compatible. + // Check for bad pointer or length value. Must be a multiple of 4 after + // accounting for the 3 byte header and not within that header. + // RFC 791, page 20 says: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // + // A recorded route is composed of a series of internet addresses. + // Each internet address is 32 bits or 4 octets. + // Linux skips this test so we must too. See Linux code at: + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { + // Length is bad, not on integral number of slots. + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + // If not length, the fault must be with the pointer. + } + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } + if usage.actions().recordRoute == optionVerify { + return 0, nil + } + rrOpt.StoreAddress(localAddress) + return 0, nil +} + +// processIPOptions parses the IPv4 options and produces a new set of options +// suitable for use in the next step of packet processing as informed by usage. +// The original will not be touched. +// +// Returns +// - The location of an error if there was one (or 0 if no error) +// - If there is an error, information as to what it was was. +// - The replacement option set. +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { + stats := e.protocol.stack.Stats() + opts := header.IPv4Options(orig) + optIter := opts.MakeIterator() + + // Each option other than NOP must only appear (RFC 791 section 3.1, at the + // definition of every type). Keep track of each of the possible types in + // the 8 bit 'type' field. + var seenOptions [math.MaxUint8 + 1]bool + + // TODO(gvisor.dev/issue/4586): + // This will need tweaking when we start really forwarding packets + // as we may need to get two addresses, for rx and tx interfaces. + // We will also have to take usage into account. + prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) + localAddress := prefixedAddress.Address + if err != nil { + h := header.IPv4(pkt.NetworkHeader().View()) + dstAddr := h.DestinationAddress() + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { + return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + } + localAddress = dstAddr + } + + for { + option, done, err := optIter.Next() + if done || err != nil { + return optIter.ErrCursor, optIter.Finalize(), err + } + optType := option.Type() + if optType == header.IPv4OptionNOPType { + optIter.PushNOPOrEnd(optType) + continue + } + if optType == header.IPv4OptionListEndType { + optIter.PushNOPOrEnd(optType) + return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + } + + // check for repeating options (multiple NOPs are OK) + if seenOptions[optType] { + return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + } + seenOptions[optType] = true + + optLen := int(option.Size()) + switch option := option.(type) { + case *header.IPv4OptionTimestamp: + stats.IP.OptionTSReceived.Increment() + if usage.actions().timestamp != optionRemove { + clock := e.protocol.stack.Clock() + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + case *header.IPv4OptionRecordRoute: + stats.IP.OptionRRReceived.Increment() + if usage.actions().recordRoute != optionRemove { + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + default: + stats.IP.OptionUnknownReceived.Increment() + if usage.actions().unknown == optionPass { + newBuffer := optIter.RemainingBuffer()[:optLen] + // Arguments already heavily checked.. ignore result. + _ = copy(newBuffer, option.Contents()) + optIter.ConsumeBuffer(optLen) + } + } + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index fee11bb38..9e2d2cfd6 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -21,11 +21,13 @@ import ( "math" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -39,7 +41,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -const extraHeaderReserve = 50 +const ( + extraHeaderReserve = 50 + defaultMTU = 65536 +) func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ @@ -47,7 +52,6 @@ func TestExcludeBroadcast(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - const defaultMTU = 65536 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { ep = sniffer.New(ep) @@ -99,11 +103,167 @@ func TestExcludeBroadcast(t *testing.T) { }) } +func TestForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + ) + + ipv4Addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + } + ipv4Addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + } + remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) + remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + + tests := []struct { + name string + TTL uint8 + expectErrorICMP bool + }{ + { + name: "TTL of zero", + TTL: 0, + expectErrorICMP: true, + }, + { + name: "TTL of one", + TTL: 1, + expectErrorICMP: false, + }, + { + name: "TTL of two", + TTL: 2, + expectErrorICMP: false, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectErrorICMP: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e1 := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} + if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) + } + + e2 := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} + if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Addr1.Subnet(), + NIC: nicID1, + }, + { + Destination: ipv4Addr2.Subnet(), + NIC: nicID2, + }, + }) + + if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + } + + totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: test.TTL, + SrcAddr: remoteIPv4Addr1, + DstAddr: remoteIPv4Addr2, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + if test.expectErrorICMP { + reply, ok := e1.Read() + if !ok { + t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(ipv4Addr1.Address), + checker.DstAddr(remoteIPv4Addr1), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4TTLExceeded), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + + if n := e2.Drain(); n != 0 { + t.Fatalf("got e2.Drain() = %d, want = 0", n) + } + } else { + reply, ok := e2.Read() + if !ok { + t.Fatal("expected ICMP Echo packet through outgoing NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(remoteIPv4Addr1), + checker.DstAddr(remoteIPv4Addr2), + checker.TTL(test.TTL-1), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Payload(nil), + ), + ) + + if n := e1.Drain(); n != 0 { + t.Fatalf("got e1.Drain() = %d, want = 0", n) + } + } + }) + } +} + // TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and // checks the response. func TestIPv4Sanity(t *testing.T) { const ( - defaultMTU = header.IPv6MinimumMTU ttl = 255 nicID = 1 randomSequence = 123 @@ -118,27 +278,29 @@ func TestIPv4Sanity(t *testing.T) { ) tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - badHeaderChecksum bool - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - shouldFail bool - expectICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - options []byte + name string + headerLength uint8 // value of 0 means "use correct size" + badHeaderChecksum bool + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + options header.IPv4Options + replyOptions header.IPv4Options // reply should look like this + shouldFail bool + expectErrorICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + paramProblemPointer uint8 }{ { - name: "valid", - maxTotalLength: defaultMTU, + name: "valid no options", + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, }, { name: "bad header checksum", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, badHeaderChecksum: true, @@ -157,47 +319,47 @@ func TestIPv4Sanity(t *testing.T) { // received with TTL less than 2. { name: "zero TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 0, - shouldFail: false, }, { name: "one TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 1, - shouldFail: false, }, { name: "End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, - options: []byte{0, 0, 0, 0}, + options: header.IPv4Options{0, 0, 0, 0}, + replyOptions: header.IPv4Options{0, 0, 0, 0}, }, { name: "NOP options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, - options: []byte{1, 1, 1, 1}, + options: header.IPv4Options{1, 1, 1, 1}, + replyOptions: header.IPv4Options{1, 1, 1, 1}, }, { name: "NOP and End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, - options: []byte{1, 1, 0, 0}, + options: header.IPv4Options{1, 1, 0, 0}, + replyOptions: header.IPv4Options{1, 1, 0, 0}, }, { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (0)", @@ -205,7 +367,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip - 1)", @@ -213,7 +374,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip + icmp - 1)", @@ -221,28 +381,465 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad protocol", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: 99, TTL: ttl, shouldFail: true, - expectICMP: true, + expectErrorICMP: true, ICMPType: header.ICMPv4DstUnreachable, ICMPCode: header.ICMPv4ProtoUnreachable, }, + { + name: "timestamp option overflow", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + replyOptions: header.IPv4Options{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + }, + { + name: "timestamp option overflow full", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 12, 13, 0xF1, + // ^ Counter full (15/0xF) + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + replyOptions: header.IPv4Options{}, + }, + { + name: "unknown option", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{10, 4, 9, 0}, + // ^^ + // The unknown option should be stripped out of the reply. + replyOptions: header.IPv4Options{}, + }, + { + name: "bad option - length 0", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 0, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + name: "bad option - length big", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 9, 9, 0, + // ^ + // There are only 8 bytes allocated to options so 9 bytes of timestamp + // space is not possible. (Second byte) + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + // This tests for some linux compatible behaviour. + // The ICMP pointer returned is 22 for Linux but the + // error is actually in spot 21. + name: "bad option - length bad", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + // Timestamps are in multiples of 4 or 8 but never 7. + // The option space should be padded out. + options: header.IPv4Options{ + 68, 7, 5, 0, + // ^ ^ Linux points here which is wrong. + // | Not a multiple of 4 + 1, 2, 3, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "multiple type 0 with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 24, 21, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + }, + replyOptions: header.IPv4Options{ + 68, 24, 25, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // The timestamp area is full so add to the overflow count. + name: "multiple type 1 timestamps", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 20, 21, 0x11, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + // Overflow count is the top nibble of the 4th byte. + replyOptions: header.IPv4Options{ + 68, 20, 21, 0x21, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + }, + { + name: "multiple type 1 timestamps with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 28, 21, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + replyOptions: header.IPv4Options{ + 68, 28, 29, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 192, 168, 1, 58, // New IP Address. + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // Timestamp pointer uses one based counting so 0 is invalid. + name: "timestamp pointer invalid", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, 0, 0x00, + // ^ 0 instead of 5 or more. + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Timestamp pointer cannot be less than 5. It must point past the header + // which is 4 bytes. (1 based counting) + name: "timestamp pointer too small by 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, + // ^ header is 4 bytes, so 4 should fail. + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "valid timestamp pointer", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, + // ^ header is 4 bytes, so 5 should succeed. + 0, 0, 0, 0, + }, + replyOptions: header.IPv4Options{ + 68, 8, 9, 0x00, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // Needs 8 bytes for a type 1 timestamp but there are only 4 free. + name: "bad timer element alignment", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 20, 17, 0x01, + // ^^ ^^ 20 byte area, next free spot at 17. + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + // End of option list with illegal option after it, which should be ignored. + { + name: "end of options list", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 10, 3, 99, + }, + replyOptions: header.IPv4Options{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, // 3 bytes unknown option + }, // ^ End of options hides following bytes. + }, + { + // Timestamp with a size too small. + name: "timestamp truncated", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{68, 1, 0, 0}, + // ^ Smallest possible is 8. + shouldFail: true, + }, + { + name: "single record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 7, 4, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: header.IPv4Options{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "multiple record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 23, 20, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + 0, + }, + replyOptions: header.IPv4Options{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "single record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, + }, + replyOptions: header.IPv4Options{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Unlike timestamp, this should just succeed. + name: "multiple record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 23, 24, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, + }, + replyOptions: header.IPv4Options{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Pointer uses one based counting so 0 is invalid. + name: "record route pointer zero", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, 0, // 3 byte header + 0, 0, 0, 0, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Pointer must be 4 or more as it must point past the 3 byte header + // using 1 based counting. 3 should fail. + name: "record route pointer too small by 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header + 0, 0, 0, 0, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Pointer must be 4 or more as it must point past the 3 byte header + // using 1 based counting. Check 4 passes. (Duplicates "single + // record route with room") + name: "valid record route pointer", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: header.IPv4Options{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm Linux bug for bug compatibility. + // Linux returns slot 22 but the error is in slot 21. + name: "multiple record route with not enough room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, 8, // 3 byte header + // ^ ^ Linux points here. We must too. + // | Not enough room. 1 byte free, need 4. + 1, 2, 3, 4, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + replyOptions: header.IPv4Options{}, + }, + { + name: "duplicate record route", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, 0, // pad + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 7, + replyOptions: header.IPv4Options{}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, }) // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, defaultMTU, "") + e := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -250,6 +847,9 @@ func TestIPv4Sanity(t *testing.T) { if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } + // Advance the clock by some unimportant amount to make + // sure it's all set up. + clock.Advance(time.Millisecond * 0x10203040) // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. @@ -260,14 +860,12 @@ func TestIPv4Sanity(t *testing.T) { }, }) - // Round up the header size to the next multiple of 4 as RFC 791, page 11 - // says: "Internet Header Length is the length of the internet header - // in 32 bit words..." and on page 23: "The internet header padding is - // used to ensure that the internet header ends on a 32 bit boundary." - ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1) - + if len(test.options)%4 != 0 { + t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) + } + ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) @@ -284,20 +882,26 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } + ip.Encode(&header.IPv4Fields{ - IHL: uint8(ipHeaderLength), TotalLength: totalLen, Protocol: test.transportProtocol, TTL: test.TTL, SrcAddr: remoteIPv4Addr, DstAddr: ipv4Addr.Address, }) - if n := copy(ip.Options(), test.options); n != len(test.options) { - t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options)) - } - // Override the correct value if the test case specified one. if test.headerLength != 0 { ip.SetHeaderLength(test.headerLength) + } else { + // Set the calculated header length, since we may manually add options. + ip.SetHeaderLength(uint8(ipHeaderLength)) + } + if len(test.options) != 0 { + // Copy options manually. We do not use Encode for options so we can + // verify malformed options with handcrafted payloads. + if want, got := copy(ip.Options(), test.options), len(test.options); want != got { + t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) + } } ip.SetChecksum(0) ipHeaderChecksum := ip.CalculateChecksum() @@ -312,14 +916,20 @@ func TestIPv4Sanity(t *testing.T) { reply, ok := e.Read() if !ok { if test.shouldFail { - if test.expectICMP { - t.Fatal("expected ICMP error response missing") + if test.expectErrorICMP { + t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) } return // Expected silent failure. } t.Fatal("expected ICMP echo reply missing") } + // We didn't expect a packet. Register our surprise but carry on to + // provide more information about what we got. + if test.shouldFail && !test.expectErrorICMP { + t.Error("unexpected packet response") + } + // Check the route that brought the packet to us. if reply.Route.LocalAddress != ipv4Addr.Address { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) @@ -328,57 +938,90 @@ func TestIPv4Sanity(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) } - // Make sure it's all in one buffer. - vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) - replyIPHeader := header.IPv4(vv.ToView()) + // Make sure it's all in one buffer for checker. + replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - // At this stage we only know it's an IP header so verify that much. + // At this stage we only know it's probably an IP+ICMP header so verify + // that much. checker.IPv4(t, replyIPHeader, checker.SrcAddr(ipv4Addr.Address), checker.DstAddr(remoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Checksum(), + ), ) - // All expected responses are ICMP packets. - if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { - t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + // Don't proceed any further if the checker found problems. + if t.Failed() { + t.FailNow() } - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - // Sanity check the response. + // OK it's ICMP. We can safely look at the type now. + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) switch replyICMPHeader.Type() { - case header.ICMPv4DstUnreachable: + case header.ICMPv4ParamProblem: + if !test.shouldFail { + t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) + } + if !test.expectErrorICMP { + t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) + } checker.IPv4(t, replyIPHeader, checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), checker.IPv4HeaderLength(header.IPv4MinimumSize), checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Checksum(), + checker.ICMPv4Pointer(test.paramProblemPointer), checker.ICMPv4Payload([]byte(hdr.View())), ), ) - if !test.shouldFail || !test.expectICMP { - t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + return + case header.ICMPv4DstUnreachable: + if !test.shouldFail { + t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + if !test.expectErrorICMP { + t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", header.ICMPv4DstUnreachable, replyICMPHeader.Code()) } + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) return case header.ICMPv4EchoReply: + if test.shouldFail { + if !test.expectErrorICMP { + t.Error("got Echo Reply packet, want no response") + } else { + t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) + } + } + // If the IP options change size then the packet will change size, so + // some IP header fields will need to be adjusted for the checks. + sizeChange := len(test.replyOptions) - len(test.options) + checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength), - checker.IPv4Options(test.options), - checker.IPFullLength(uint16(requestPkt.Size())), + checker.IPv4HeaderLength(ipHeaderLength+sizeChange), + checker.IPv4Options(test.replyOptions), + checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), checker.ICMPv4( + checker.ICMPv4Checksum(), checker.ICMPv4Code(header.ICMPv4UnusedCode), checker.ICMPv4Seq(randomSequence), checker.ICMPv4Ident(randomIdent), - checker.ICMPv4Checksum(), ), ) - if test.shouldFail { - t.Fatalf("unexpected Echo Reply packet\n") - } default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) } }) } @@ -462,7 +1105,7 @@ var fragmentationTests = []struct { wantFragments []fragmentInfo }{ { - description: "No Fragmentation", + description: "No fragmentation", mtu: 1280, gso: nil, transportHeaderLength: 0, @@ -483,6 +1126,30 @@ var fragmentationTests = []struct { }, }, { + description: "Fragmented with the minimum mtu", + mtu: header.IPv4MinimumMTU, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv4MinimumMTU + 1, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { description: "No fragmentation with big header", mtu: 2000, gso: nil, @@ -647,43 +1314,50 @@ func TestFragmentationWritePackets(t *testing.T) { } } -// TestFragmentationErrors checks that errors are returned from write packet +// TestFragmentationErrors checks that errors are returned from WritePacket // correctly. func TestFragmentationErrors(t *testing.T) { const ttl = 42 - expectedError := tcpip.ErrAborted - fragTests := []struct { + tests := []struct { description string mtu uint32 transportHeaderLength int payloadSize int allowPackets int - fragmentCount int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error }{ { description: "No frag", mtu: 2000, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 1, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 3, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on second frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 1, - fragmentCount: 3, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag MTU smaller than header", @@ -691,28 +1365,40 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 1000, payloadSize: 500, allowPackets: 0, - fragmentCount: 4, + outgoingErrors: 4, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error when MTU is smaller than IPv4 minimum MTU", + mtu: header.IPv4MinimumMTU - 1, + transportHeaderLength: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, }, } - for _, ft := range fragTests { + for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) - r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != expectedError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) } - if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) } - if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) } }) } @@ -739,12 +1425,13 @@ func TestInvalidFragments(t *testing.T) { } type fragmentData struct { - ipv4fields header.IPv4Fields + ipv4fields header.IPv4Fields + // 0 means insert the correct IHL. Non 0 means override the correct IHL. + overrideIHL int // For 0 use 1 as it is an int and will be divided by 4. payload []byte - autoChecksum bool // if true, the Checksum field will be overwritten. + autoChecksum bool // If true, the Checksum field will be overwritten. } - // These packets have both IHL and TotalLength set to 0. tests := []struct { name string fragments []fragmentData @@ -756,7 +1443,6 @@ func TestInvalidFragments(t *testing.T) { fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ - IHL: 0, TOS: tos, TotalLength: 0, ID: ident, @@ -767,6 +1453,7 @@ func TestInvalidFragments(t *testing.T) { SrcAddr: addr1, DstAddr: addr2, }, + overrideIHL: 1, // See note above. payload: payloadGen(12), autoChecksum: true, }, @@ -779,7 +1466,6 @@ func TestInvalidFragments(t *testing.T) { fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ - IHL: 0, TOS: tos, TotalLength: 0, ID: ident, @@ -790,6 +1476,7 @@ func TestInvalidFragments(t *testing.T) { SrcAddr: addr1, DstAddr: addr2, }, + overrideIHL: 1, // See note above. payload: payloadGen(12), autoChecksum: true, }, @@ -804,7 +1491,6 @@ func TestInvalidFragments(t *testing.T) { fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TOS: tos, TotalLength: header.IPv4MinimumSize + 17, ID: ident, @@ -829,7 +1515,6 @@ func TestInvalidFragments(t *testing.T) { fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TOS: tos, TotalLength: header.IPv4MinimumSize + 16, ID: ident, @@ -852,7 +1537,6 @@ func TestInvalidFragments(t *testing.T) { fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize - 12, TOS: tos, TotalLength: header.IPv4MinimumSize + 28, ID: ident, @@ -864,11 +1548,11 @@ func TestInvalidFragments(t *testing.T) { DstAddr: addr2, }, payload: payloadGen(28), + overrideIHL: header.IPv4MinimumSize - 12, autoChecksum: true, }, { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize - 12, TOS: tos, TotalLength: header.IPv4MinimumSize - 12, ID: ident, @@ -880,6 +1564,7 @@ func TestInvalidFragments(t *testing.T) { DstAddr: addr2, }, payload: payloadGen(28), + overrideIHL: header.IPv4MinimumSize - 12, autoChecksum: true, }, }, @@ -891,7 +1576,6 @@ func TestInvalidFragments(t *testing.T) { fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize + 4, TOS: tos, TotalLength: header.IPv4MinimumSize + 28, ID: ident, @@ -903,11 +1587,11 @@ func TestInvalidFragments(t *testing.T) { DstAddr: addr2, }, payload: payloadGen(28), + overrideIHL: header.IPv4MinimumSize + 4, autoChecksum: true, }, { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize + 4, TOS: tos, TotalLength: header.IPv4MinimumSize + 4, ID: ident, @@ -919,6 +1603,7 @@ func TestInvalidFragments(t *testing.T) { DstAddr: addr2, }, payload: payloadGen(28), + overrideIHL: header.IPv4MinimumSize + 4, autoChecksum: true, }, }, @@ -930,7 +1615,6 @@ func TestInvalidFragments(t *testing.T) { fragments: []fragmentData{ { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TOS: tos, TotalLength: header.IPv4MinimumSize + 8, ID: ident, @@ -946,7 +1630,6 @@ func TestInvalidFragments(t *testing.T) { }, { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TOS: tos, TotalLength: header.IPv4MinimumSize + 8, ID: ident, @@ -962,7 +1645,6 @@ func TestInvalidFragments(t *testing.T) { }, { ipv4fields: header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TOS: tos, TotalLength: header.IPv4MinimumSize + 8, ID: ident, @@ -984,7 +1666,6 @@ func TestInvalidFragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -1004,6 +1685,11 @@ func TestInvalidFragments(t *testing.T) { ip := header.IPv4(hdr.Prepend(pktSize)) ip.Encode(&f.ipv4fields) + // Encode sets this up correctly. If we want a different value for + // testing then we need to overwrite the good value. + if f.overrideIHL != 0 { + ip.SetHeaderLength(uint8(f.overrideIHL)) + } copy(ip[header.IPv4MinimumSize:], f.payload) if f.autoChecksum { @@ -1027,6 +1713,251 @@ func TestInvalidFragments(t *testing.T) { } } +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 99 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + { + ipv4fields: header.IPv4Fields{ + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + { + ipv4fields: header.IPv4Fields{ + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + { + ipv4fields: header.IPv4Fields{ + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + }, + Clock: clock, + }) + e := channel.New(1, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + hdr := buffer.NewPrependable(pktSize) + + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && ip.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(header.IPv4ProtocolNumber, pkt) + } + + clock.Advance(ipv4.ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(firstFragmentSent)), + ), + ) + }) + } +} + // TestReceiveFragments feeds fragments in through the incoming packet path to // test reassembly func TestReceiveFragments(t *testing.T) { @@ -1392,6 +2323,28 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, + { + name: "Two fragments with MF flag reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload4Addr1ToAddr2[:65512], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 65512, + payload: ipv4Payload4Addr1ToAddr2[65512:], + }, + }, + expectedPayloads: nil, + }, } for _, test := range tests { @@ -1432,7 +2385,6 @@ func TestReceiveFragments(t *testing.T) { // Serialize IPv4 fixed header. ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)), ID: frag.id, Flags: frag.flags, @@ -1506,13 +2458,10 @@ func TestWriteStats(t *testing.T) { // Install Output DROP rule. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } }, @@ -1527,17 +2476,14 @@ func TestWriteStats(t *testing.T) { // of the 3 packets. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } }, @@ -1577,7 +2523,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -1592,7 +2538,7 @@ func TestWriteStats(t *testing.T) { test.setup(t, rt.Stack()) - nWritten, _ := writer.writePackets(&rt, pkts) + nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) @@ -1609,7 +2555,7 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) @@ -1704,7 +2650,6 @@ func TestPacketQueing(t *testing.T) { u.SetChecksum(^u.CalculateChecksum(sum)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, TTL: ipv4.DefaultTTL, Protocol: uint8(udp.ProtocolNumber), @@ -1724,8 +2669,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -1748,7 +2693,6 @@ func TestPacketQueing(t *testing.T) { pkt.SetChecksum(^header.Checksum(pkt, 0)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(totalLen), Protocol: uint8(icmp.ProtocolNumber4), TTL: ipv4.DefaultTTL, @@ -1768,8 +2712,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -1783,7 +2727,7 @@ func TestPacketQueing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, @@ -1793,9 +2737,6 @@ func TestPacketQueing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) } @@ -1820,8 +2761,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != arp.ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) } - if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress) } rep := header.ARP(p.Pkt.NetworkHeader().View()) if got := rep.Op(); got != header.ARPRequest { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index a30437f02..afa45aefe 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -8,6 +8,7 @@ go_library( "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "mld.go", "ndp.go", ], visibility = ["//visibility:public"], @@ -19,6 +20,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -36,6 +38,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", @@ -48,3 +51,19 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "ipv6_x_test", + size = "small", + srcs = ["mld_test.go"], + deps = [ + ":ipv6", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ead6bedcb..6ee162713 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -124,10 +124,10 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { }) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { - stats := r.Stats().ICMP - sent := stats.V6PacketsSent - received := stats.V6PacketsReceived +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { + stats := e.protocol.stack.Stats().ICMP + sent := stats.V6.PacketsSent + received := stats.V6.PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. @@ -138,13 +138,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } h := header.ICMPv6(v) iph := header.IPv6(pkt.NetworkHeader().View()) + srcAddr := iph.SourceAddress() + dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. // // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) payload.TrimFront(len(h)) - if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { received.Invalid.Increment() return } @@ -161,7 +163,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } // TODO(b/112892170): Meaningfully handle all ICMP types. - switch h.Type() { + switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) @@ -170,8 +172,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -221,7 +226,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // we know we are also performing DAD on it). In this case we let the // stack know so it can handle such a scenario and do nothing further with // the NS. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { // We would get an error if the address no longer exists or the address // is no longer tentative (DAD resolved between the call to // hasTentativeAddr and this point). Both of these are valid scenarios: @@ -248,7 +253,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } @@ -274,9 +279,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Otherwise, on link layers that have addresses this option MUST be // included in multicast solicitations and SHOULD be included in unicast // solicitations. - unspecifiedSource := r.RemoteAddress == header.IPv6Any + unspecifiedSource := srcAddr == header.IPv6Any if len(sourceLinkAddr) == 0 { - if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource { received.Invalid.Increment() return } @@ -284,9 +289,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } else if e.nud != nil { - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr) } // As per RFC 4861 section 7.1.1: @@ -295,7 +300,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // ... // - If the IP source address is the unspecified address, the IP // destination address is a solicited-node multicast address. - if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) { received.Invalid.Increment() return } @@ -305,7 +310,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the source of the solicitation is the unspecified address, the node // MUST [...] and multicast the advertisement to the all-nodes address. // - remoteAddr := r.RemoteAddress + remoteAddr := srcAddr if unspecifiedSource { remoteAddr = header.IPv6AllNodesMulticastAddress } @@ -353,7 +358,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(packet.NDPPayload()) + na := header.NDPNeighborAdvert(packet.MessageBody()) // As per RFC 4861 section 7.2.4: // @@ -462,12 +467,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. - localAddr := r.LocalAddress - if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr := dstAddr + if header.IsV6MulticastAddress(dstAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -483,7 +488,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -495,7 +504,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -516,7 +525,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - stack := r.Stack() + stack := e.protocol.stack // Is the networking stack operating as a router? if !stack.Forwarding(ProtocolNumber) { @@ -547,7 +556,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST // NOT be included when the source IP address is the unspecified address. // Otherwise, it SHOULD be included on link layers that have addresses. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { received.Invalid.Increment() return } @@ -555,7 +564,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme if e.nud != nil { // A RS with a specified source IP address modifies the NUD state // machine in the same way a reachability probe would. - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } } @@ -572,7 +581,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - routerAddr := iph.SourceAddress() + routerAddr := srcAddr // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { @@ -605,7 +614,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } e.mu.Lock() @@ -635,8 +644,39 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } + case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + received.MulticastListenerQuery.Increment() + case header.ICMPv6MulticastListenerReport: + received.MulticastListenerReport.Increment() + case header.ICMPv6MulticastListenerDone: + received.MulticastListenerDone.Increment() + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) + } + + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { + received.Invalid.Increment() + return + } + + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + e.mu.Lock() + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerReport: + e.mu.Lock() + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerDone: + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) + } + default: - received.Invalid.Increment() + received.Unrecognized.Increment() } } @@ -648,52 +688,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - // TODO(b/148672031): Use stack.FindRoute instead of manually creating the - // route here. Note, we would need the nicID to do this properly so the right - // NIC (associated to linkEP) is used to send the NDP NS message. - r := stack.Route{ - LocalAddress: localAddr, - RemoteAddress: addr, - LocalLinkAddress: linkEP.LinkAddress(), - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + remoteAddr := targetAddr + if len(remoteLinkAddr) == 0 { + remoteAddr = header.SolicitedNodeAddr(targetAddr) + remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - // If a remote address is not already known, then send a multicast - // solicitation since multicast addresses have a static mapping to link - // addresses. - if len(r.RemoteLinkAddress) == 0 { - r.RemoteAddress = header.SolicitedNodeAddr(addr) - r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress) + r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err } + defer r.Release() + r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()), + header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize, + ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(packet.NDPPayload()) - ns.SetTargetAddress(addr) + ns := header.NDPNeighborSolicit(packet.MessageBody()) + ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) + stat := p.stack.Stats().ICMP.V6.PacketsSent + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, pkt); err != nil { + stat.Dropped.Increment() + return err + } - // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt) + stat.NeighborSolicit.Increment() + return nil } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -747,9 +781,26 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in +// transit to its final destination, as per RFC 4443 section 3.3. +type icmpReasonHopLimitExceeded struct{} + +func (*icmpReasonHopLimitExceeded) isICMPReason() {} + +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv6(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // @@ -776,32 +827,49 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac allowResponseToMulticast = reason.respondToMulticast } - if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst) + if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any { return nil } + // If we hit a Hop Limit Exceeded error, then we know we are operating as a + // router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard the + // packet and originate an ICMPv6 Time Exceeded message with Code 0 to + // the source of the packet. This indicates either a routing loop or + // too small an initial Hop Limit value. + // + // If we are operating as a router, do not use the packet's destination + // address as the response's source address as we should not own the + // destination address of a packet we are forwarding. + // + // If the packet was originally destined to a multicast address, then do not + // use the packet's destination address as the source for the response ICMP + // packet as "multicast addresses must not be used as source addresses in IPv6 + // packets", as per RFC 4291 section 2.7. + localAddr := origIPHdrDst + if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { + localAddr = "" + } // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil stats := p.stack.Stats().ICMP - sent := stats.V6PacketsSent + sent := stats.V6.PacketsSent if !p.stack.AllowICMPMessage() { sent.RateLimited.Increment() return nil } - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. // Unfortunately at this time ICMP Packets do not have a transport @@ -819,6 +887,8 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac } } + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + // As per RFC 4443 section 2.4 // // (c) Every ICMPv6 error message (type < 128) MUST include @@ -839,7 +909,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + payload := network.ToVectorisedView() + payload.AppendView(transport) + payload.Append(pkt.Data) payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -860,6 +932,14 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.DstUnreachable + case *icmpReasonHopLimitExceeded: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) + counter = sent.TimeExceeded + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) + counter = sent.TimeExceeded default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } @@ -879,3 +959,16 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac counter.Increment() return nil } + +// OnReassemblyTimeout implements fragmentation.TimeoutHandler. +func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + // OnReassemblyTimeout sends a Time Exceeded Message as per RFC 2460 Section + // 4.5: + // + // If the first fragment (i.e., the one with a Fragment Offset of zero) has + // been received, an ICMP Time Exceeded -- Fragment Reassembly Time Exceeded + // message should be sent to the source of that fragment. + if pkt != nil { + p.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8dc33c560..02b18e9a5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -51,6 +51,7 @@ const ( var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) + lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -86,7 +87,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { +func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { return stack.TransportPacketHandled } @@ -108,31 +109,27 @@ type stubNUDHandler struct { var _ stack.NUDHandler = (*stubNUDHandler)(nil) -func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { s.probeCount++ } -func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { s.confirmationCount++ } -func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { +func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { } var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - stack.NetworkLinkEndpoint - - linkAddr tcpip.LinkAddress -} + stack.LinkEndpoint -func (i *testInterface) LinkAddress() tcpip.LinkAddress { - return i.linkAddr + nicID tcpip.NICID } func (*testInterface) ID() tcpip.NICID { - return 0 + return nicID } func (*testInterface) IsLoopback() bool { @@ -147,6 +144,18 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + } + r.ResolveWith(remoteLinkAddr) + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -169,13 +178,8 @@ func TestICMPCounts(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -201,11 +205,16 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -262,6 +271,22 @@ func TestICMPCounts(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -271,20 +296,20 @@ func TestICMPCounts(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - ep.HandlePacket(&r, pkt) + ep.HandlePacket(pkt) } for _, typ := range types { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -292,7 +317,7 @@ func TestICMPCounts(t *testing.T) { // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { if got, want := s.Value(), uint64(1); got != want { t.Errorf("got %s = %d, want = %d", name, got, want) @@ -311,13 +336,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -343,11 +363,16 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -404,6 +429,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -413,20 +454,20 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - ep.HandlePacket(&r, pkt) + ep.HandlePacket(pkt) } for _, typ := range types { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -434,7 +475,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { if got, want := s.Value(), uint64(1); got != want { t.Errorf("got %s = %d, want = %d", name, got, want) @@ -559,8 +600,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. return } - if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { - t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) + if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr { + t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr) } // Pull the full payload since network header. Needed for header.IPv6 to @@ -812,11 +853,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), @@ -824,7 +865,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) @@ -889,11 +930,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1007,11 +1048,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(icmpSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1019,7 +1060,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid typStat := typ.statCounter(stats) @@ -1067,11 +1108,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1186,11 +1227,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(size + payloadSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), @@ -1198,7 +1239,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid typStat := typ.statCounter(stats) @@ -1235,26 +1276,72 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { } func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + snaddr := header.SolicitedNodeAddr(lladdr0) mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) tests := []struct { - name string - remoteLinkAddr tcpip.LinkAddress - expectedLinkAddr tcpip.LinkAddress - expectedAddr tcpip.Address + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedRemoteAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", - remoteLinkAddr: linkAddr1, - expectedLinkAddr: linkAddr1, - expectedAddr: lladdr0, + name: "Unicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, + }, + { + name: "Multicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, }, { - name: "Multicast", - remoteLinkAddr: "", - expectedLinkAddr: mcaddr, - expectedAddr: snaddr, + name: "Multicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Unicast with no local address available", + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, }, } @@ -1269,26 +1356,43 @@ func TestLinkAddressRequest(t *testing.T) { } linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } + } + + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return } pkt, ok := linkEP.Read() if !ok { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr) + if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) } - if pkt.Route.RemoteAddress != test.expectedAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr) + if pkt.Route.RemoteAddress != test.expectedRemoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) } if pkt.Route.LocalAddress != lladdr1 { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) } checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedAddr), + checker.DstAddr(test.expectedRemoteAddr), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(lladdr0), @@ -1341,11 +1445,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1359,8 +1463,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1383,11 +1487,11 @@ func TestPacketQueing(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1401,8 +1505,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1452,8 +1556,8 @@ func TestPacketQueing(t *testing.T) { t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) } snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1471,7 +1575,7 @@ func TestPacketQueing(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) @@ -1482,11 +1586,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1520,7 +1624,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1540,7 +1644,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1557,7 +1661,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1573,7 +1677,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1590,7 +1694,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1611,7 +1715,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1630,7 +1734,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1650,7 +1754,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1698,37 +1802,39 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } nudHandler := &stubNUDHandler{} - ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() - - // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. - r.LocalAddress = test.destination icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.IPv6MinimumSize, Data: buffer.View(icmp).ToVectorisedView(), }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.RemoteAddress, - DstAddr: r.LocalAddress, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.source, + DstAddr: test.destination, }) - ep.HandlePacket(&r, pkt) + ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. if nudHandler.probeCount != test.wantProbeCount { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 9670696c7..a49b5ac77 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "fmt" "hash/fnv" + "math" "sort" "sync/atomic" "time" @@ -34,19 +35,21 @@ import ( ) const ( + // ReassembleTimeout controls how long a fragment will be held. // As per RFC 8200 section 4.5: + // // If insufficient fragments are received to complete reassembly of a packet // within 60 seconds of the reception of the first-arriving fragment of that // packet, reassembly of that packet must be abandoned. // // Linux also uses 60 seconds for reassembly timeout: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456 - reassembleTimeout = 60 * time.Second + ReassembleTimeout = 60 * time.Second // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber - // maxTotalSize is maximum size that can be encoded in the 16-bit + // maxPayloadSize is the maximum size that can be encoded in the 16-bit // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff @@ -83,6 +86,7 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState + mld mldState } } @@ -118,6 +122,45 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// onAddressAssignedLocked handles an address being assigned. +// +// Precondition: e.mu must be exclusively locked. +func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, ... + // + // If we just completed DAD for a link-local address, then attempt to send any + // queued MLD reports. Note, we may have sent reports already for some of the + // groups before we had a valid link-local address to use as the source for + // the MLD messages, but that was only so that MLD snooping switches are aware + // of our membership to groups - routers would not have handled those reports. + // + // As per RFC 3590 section 4, + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + if header.IsV6LinkLocalAddress(addr) { + e.mu.mld.sendQueuedReports() + } +} + // InvalidateDefaultRouter implements stack.NDPEndpoint. func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() @@ -166,7 +209,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return err } - prefix := addressEndpoint.AddressWithPrefix().Subnet() + prefix := addressEndpoint.Subnet() switch t := addressEndpoint.ConfigType(); t { case stack.AddressConfigStatic: @@ -224,6 +267,12 @@ func (e *endpoint) Enable() *tcpip.Error { return nil } + // Groups may have been joined when the endpoint was disabled, or the + // endpoint may have left groups from the perspective of MLD when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + e.mu.mld.initializeAll() + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -241,8 +290,10 @@ func (e *endpoint) Enable() *tcpip.Error { // (NDP NS) messages may be sent to the All-Nodes multicast group if the // source address of the NDP NS is the unspecified address, as per RFC 4861 // section 7.2.4. - if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { - return err + if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err)) } // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent @@ -251,7 +302,7 @@ func (e *endpoint) Enable() *tcpip.Error { // Addresses may have aleady completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. var err *tcpip.Error - e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { addr := addressEndpoint.AddressWithPrefix().Address if !header.IsV6UnicastAddress(addr) { return true @@ -273,7 +324,7 @@ func (e *endpoint) Enable() *tcpip.Error { } // Do not auto-generate an IPv6 link-local address for loopback devices. - if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() { + if e.protocol.options.AutoGenLinkLocal && !e.nic.IsLoopback() { // The valid and preferred lifetime is infinite for the auto-generated // link-local address. e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) @@ -322,7 +373,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.Enabled() { return } @@ -331,9 +382,17 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } + + // Leave groups from the perspective of MLD so that routers know that + // we are no longer interested in the group. + e.mu.mld.softLeaveAll() + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -341,7 +400,7 @@ func (e *endpoint) disableLocked() { // Precondition: e.mu must be write locked. func (e *endpoint) stopDADForPermanentAddressesLocked() { // Stop DAD for all the tentative unicast addresses. - e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { if addressEndpoint.GetKind() != stack.PermanentTentative { return true } @@ -363,50 +422,75 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { + // TODO(gvisor.dev/issues/5035): The maximum header length returned here does + // not open the possibility for the caller to know about size required for + // extension headers. return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) { + extHdrsLen := extensionHeaders.Length() + length := pkt.Size() + extensionHeaders.Length() + if length > math.MaxUint16 { + panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16)) + } + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(params.Protocol), - HopLimit: params.TTL, - TrafficClass: params.TOS, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(length), + TransportProtocol: params.Protocol, + HopLimit: params.TTL, + TrafficClass: params.TOS, + SrcAddr: srcAddr, + DstAddr: dstAddr, + ExtensionHeaders: extensionHeaders, }) pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. The transport -// header protocol number is required to avoid parsing the IPv6 extension -// headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) - if fragMTU < pkt.TransportHeader().View().Size() { +// original packet. The transport header protocol number is required to avoid +// parsing the IPv6 extension headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // maximum payload length because they should only be transmitted once. + fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7 + if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit { + // We need at least 8 bytes of space left for the fragmentable part because + // the fragment payload must obviously be non-zero and must be a multiple + // of 8 as per RFC 8200 section 4.5: + // Each complete fragment, except possibly the last ("rightmost") one, is + // an integer multiple of 8 octets long. + return 0, 1, tcpip.ErrMessageTooLong + } + + if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { // As per RFC 8200 Section 4.5, the Transport Header is expected to be small // enough to fit in the first fragment. return 0, 1, tcpip.ErrMessageTooLong } - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) - networkHeader := header.IPv6(pkt.NetworkHeader().View()) var n int for { @@ -423,18 +507,14 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt, params.Protocol) -} + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */) -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -448,28 +528,43 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) + } return nil } } - if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ - // The inbound path expects an unparsed packet. - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) + return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) +} - loopedR.Release() +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { + if r.Loop&stack.PacketLoop != 0 { + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -499,13 +594,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r, pb, params) - if e.packetMustBeFragmented(pb, gso) { + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */) + + networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + if packetMustBeFragmented(pb, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt @@ -522,8 +624,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -546,10 +647,13 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) + } n++ continue } @@ -569,7 +673,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required checks. h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) @@ -602,27 +706,115 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt, proto) + return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv6(pkt.NetworkHeader().View()) + hopLimit := h.HopLimit() + if hopLimit <= 1 { + // As per RFC 4443 section 3.3, + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard the + // packet and originate an ICMPv6 Time Exceeded message with Code 0 to + // the source of the packet. This indicates either a routing loop or + // too small an initial Hop Limit value. + return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + } + + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader())) + + // As per RFC 8200 section 3, + // + // Hop Limit 8-bit unsigned integer. Decremented by 1 by + // each node that forwards the packet. + newHdr.SetHopLimit(hopLimit - 1) + + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(newHdr).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() // As per RFC 4291 section 2.7: // Multicast addresses must not be used as source addresses in IPv6 // packets or appear in any Routing header. - if header.IsV6MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if header.IsV6MulticastAddress(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() + return + } + + // The destination address should be an address we own or a group we joined + // for us to receive the packet. Otherwise, attempt to forward the packet. + if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + } else if !e.IsInGroup(dstAddr) { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) return } @@ -638,10 +830,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -651,7 +842,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -663,7 +854,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: previousHeaderStart, }, pkt) @@ -675,7 +866,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -689,7 +880,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -702,7 +893,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -727,7 +918,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // header, so we just make sure Segments Left is zero before processing // the next extension header. if extHdr.SegmentsLeft() != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), }, pkt) @@ -747,6 +938,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { continue } + fragmentFieldOffset := it.ParseOffset() + // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. @@ -762,8 +955,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { it, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if done { @@ -790,8 +983,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { switch lastHdr.(type) { case header.IPv6RawPayloadHeader: default: - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } } @@ -799,19 +992,47 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + return + } + + // As per RFC 2460 Section 4.5: + // + // If the length of a fragment, as derived from the fragment packet's + // Payload Length field, is not a multiple of 8 octets and the M flag + // of that fragment is 1, then that fragment must be discarded and an + // ICMP Parameter Problem, Code 0, message should be sent to the source + // of the fragment, pointing to the Payload Length field of the + // fragment packet. + if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: header.IPv6PayloadLenOffset, + }, pkt) return } // The packet is a fragment, let's try to reassemble it. start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - // Drop the fragment if the size of the reassembled payload would exceed - // the maximum payload size. + // As per RFC 2460 Section 4.5: + // + // If the length and offset of a fragment are such that the Payload + // Length of the packet reassembled from that fragment would exceed + // 65,535 octets, then that fragment must be discarded and an ICMP + // Parameter Problem, Code 0, message should be sent to the source of + // the fragment, pointing to the Fragment Offset field of the fragment + // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: fragmentFieldOffset, + }, pkt) return } @@ -821,24 +1042,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ - Source: h.SourceAddress(), - Destination: h.DestinationAddress(), + Source: srcAddr, + Destination: dstAddr, ID: extHdr.ID(), }, start, start+uint16(fragmentPayloadLen)-1, extHdr.More(), uint8(rawPayload.Identifier), - rawPayload.Buf, + pkt, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } - pkt.Data = data if ready { + pkt.Data = data + // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC // 8200 section 4.5. We also use the NextHeader value from the first @@ -852,7 +1074,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -866,7 +1088,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -879,7 +1101,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -902,13 +1124,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf - r.Stats().IP.PacketsDelivered.Increment() + stats.IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt, hasFragmentHeader) + e.handleICMP(pkt, hasFragmentHeader) } else { - r.Stats().IP.PacketsDelivered.Increment() - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + stats.IP.PacketsDelivered.Increment() + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC 4443 section 3.1: @@ -916,7 +1138,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -937,9 +1159,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + // The location of the Next Header field is in a different place in + // the initial IPv6 header than it is in the extension headers so + // treat it specially. + prevHdrIDOffset := uint32(header.IPv6NextHeaderOffset) + if previousHeaderStart != 0 { + prevHdrIDOffset = previousHeaderStart + } + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, - pointer: it.ParseOffset(), + pointer: prevHdrIDOffset, }, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) @@ -947,12 +1176,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ - code: header.ICMPv6UnknownHeader, - pointer: it.ParseOffset(), - }, pkt) - r.Stats().UnknownProtocolRcvdPackets.Increment() - return + // Since the iterator returns IPv6RawPayloadHeader for unknown Extension + // Header IDs this should never happen unless we missed a supported type + // here. + panic(fmt.Sprintf("unrecognized type from it.Next() = %T", extHdr)) + } } } @@ -1000,11 +1228,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre return addressEndpoint, nil } - snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { - return nil, err - } - addressEndpoint.SetKind(stack.PermanentTentative) if e.Enabled() { @@ -1013,6 +1236,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } } + snmc := header.SolicitedNodeAddr(addr.Address) + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) + } + return addressEndpoint, nil } @@ -1058,7 +1288,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + // The endpoint may have already left the multicast group. + if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1081,7 +1312,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool { // // Precondition: e.mu must be read or write locked. func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint { - return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr) + return e.mu.addressableEndpointState.GetAddress(localAddr) } // MainAddress implements stack.AddressableEndpoint. @@ -1113,6 +1344,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) } +// getLinkLocalAddressRLocked returns a link-local address from the primary list +// of addresses, if one is available. +// +// See stack.PrimaryEndpointBehavior for more details about the primary list. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { + var linkLocalAddr tcpip.Address + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.IsAssigned(false /* allowExpired */) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + linkLocalAddr = addr + return false + } + } + return true + }) + return linkLocalAddr +} + // acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress // but with locking requirements. // @@ -1132,10 +1383,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // Create a candidate set of available addresses we can potentially use as a // source address. var cs []addrCandidate - e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { // If r is not valid for outgoing connections, it is not a valid endpoint. if !addressEndpoint.IsAssigned(allowExpired) { - return + return true } addr := addressEndpoint.AddressWithPrefix().Address @@ -1151,6 +1402,8 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address addressEndpoint: addressEndpoint, scope: scope, }) + + return true }) remoteScope, err := header.ScopeForIPv6Address(remoteAddr) @@ -1223,35 +1476,52 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return false, tcpip.ErrBadAddress + return tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + e.mu.mld.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mu.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mu.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) +var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { - stack *stack.Stack + stack *stack.Stack + options Options mu struct { sync.RWMutex @@ -1275,26 +1545,6 @@ type protocol struct { forwarding uint32 fragmentation *fragmentation.Fragmentation - - // ndpDisp is the NDP event dispatcher that is used to send the netstack - // integrator NDP related events. - ndpDisp NDPDispatcher - - // ndpConfigs is the default NDP configurations used by an IPv6 endpoint. - ndpConfigs NDPConfigurations - - // opaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. - opaqueIIDOpts OpaqueInterfaceIdentifierOptions - - // tempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - tempIIDSeed []byte - - // autoGenIPv6LinkLocal determines whether or not the stack attempts to - // auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. - autoGenIPv6LinkLocal bool } // Number returns the ipv6 protocol number. @@ -1327,16 +1577,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp = ndpState{ - ep: e, - configs: p.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - e.mu.ndp.initializeTempAddrState() + e.mu.ndp.init(e) + e.mu.mld.init(e) + e.mu.Unlock() p.mu.Lock() defer p.mu.Unlock() @@ -1427,14 +1672,31 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - mtu -= header.IPv6MinimumSize - if mtu <= maxPayloadSize { - return mtu +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload MTU and the length of every IPv6 header. +// Note that this is different than the Payload Length field of the IPv6 header, +// which includes the length of the extension headers. +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv6MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState + } + + // As per RFC 7112 section 5, we should discard packets if their IPv6 header + // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not + // support PMTU discovery: + // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain + // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280 + // bytes ensures that the header chain length does not exceed the IPv6 + // minimum MTU. + if networkHeadersLen > header.IPv6MinimumMTU { + return 0, tcpip.ErrMalformedHeader + } + + networkMTU := linkMTU - uint32(networkHeadersLen) + if networkMTU > maxPayloadSize { + networkMTU = maxPayloadSize } - return maxPayloadSize + return networkMTU, nil } // Options holds options to configure a new protocol. @@ -1442,17 +1704,17 @@ type Options struct { // NDPConfigs is the default NDP configurations used by interfaces. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determines whether or not the stack attempts to - // auto-generate an IPv6 link-local address for newly enabled non-loopback + // AutoGenLinkLocal determines whether or not the stack attempts to + // auto-generate a link-local address for newly enabled non-loopback // NICs. // // Note, setting this to true does not mean that a link-local address is // assigned right away, or at all. If Duplicate Address Detection is enabled, // an address is only assigned if it successfully resolves. If it fails, no - // further attempts are made to auto-generate an IPv6 link-local adddress. + // further attempts are made to auto-generate a link-local adddress. // // The generated link-local address follows RFC 4291 Appendix A guidelines. - AutoGenIPv6LinkLocal bool + AutoGenLinkLocal bool // NDPDisp is the NDP event dispatcher that an integrator can provide to // receive NDP related events. @@ -1476,6 +1738,9 @@ type Options struct { // seed that is too small would reduce randomness and increase predictability, // defeating the purpose of temporary SLAAC addresses. TempIIDSeed []byte + + // MLD holds options for MLD. + MLD MLDOptions } // NewProtocolWithOptions returns an IPv6 network protocol. @@ -1487,17 +1752,13 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ - stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), - ids: ids, - hashIV: hashIV, - - ndpDisp: opts.NDPDisp, - ndpConfigs: opts.NDPConfigs, - opaqueIIDOpts: opts.OpaqueIIDOpts, - tempIIDSeed: opts.TempIIDSeed, - autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + stack: s, + options: opts, + + ids: ids, + hashIV: hashIV, } + p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[*endpoint]struct{}) p.SetDefaultTTL(DefaultTTL) return p @@ -1509,23 +1770,6 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return NewProtocolWithOptions(Options{})(s) } -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are - // supported for outbound packets, their length should not affect the fragment - // MTU because they should only be transmitted once. - mtu -= uint32(pkt.NetworkHeader().View().Size()) - mtu -= header.IPv6FragmentHeaderSize - // Round the MTU down to align to 8 bytes. - mtu &^= 7 - if mtu <= maxPayloadSize { - return mtu - } - return maxPayloadSize -} - func calculateFragmentReserve(pkt *stack.PacketBuffer) int { return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize } @@ -1558,23 +1802,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea fragPkt.NetworkProtocolNumber = ProtocolNumber originalIPHeadersLength := len(originalIPHeaders) - fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + + s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + M: more, + Identification: id, + }} + + fragmentIPHeadersLength := originalIPHeadersLength + s.Length() fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) } - fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) - fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) - fragmentHeader.Encode(&header.IPv6FragmentFields{ - M: more, - FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), - Identification: id, - NextHeader: uint8(transportProto), - }) + nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:]) + + fragmentIPHeaders.SetNextHeader(nextHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) return fragPkt, more } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 297868f24..5f07d3af8 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -18,12 +18,14 @@ import ( "encoding/hex" "fmt" "math" + "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" @@ -49,6 +51,7 @@ const ( fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) + unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier) extraHeaderReserve = 50 ) @@ -66,18 +69,18 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived if got := stats.NeighborAdvert.Value(); got != want { t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) @@ -124,11 +127,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -238,7 +241,7 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) + e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } @@ -271,7 +274,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -571,6 +574,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { expectICMP: false, }, { + name: "unknown next header (first)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, 63, 4, 1, 2, 3, 4, + }, unknownHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6NextHeaderOffset, + }, + { + name: "unknown next header (not first)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + unknownHdrID, 0, + 63, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { name: "destination with unknown option skippable action", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ @@ -753,11 +783,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { pointer: header.IPv6FixedHeaderSize, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, - shouldAccept: false, - }, - { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ @@ -825,7 +850,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -871,7 +896,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { Length: uint16(udpLength), }) copy(u.Payload(), udpPayload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } + + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength)) sum = header.Checksum(udpPayload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) @@ -882,16 +913,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - dstAddr := tcpip.Address(addr2) - if test.multicast { - dstAddr = header.IPv6AllNodesMulticastAddress - } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), - NextHeader: ipv6NextHdr, - HopLimit: 255, - SrcAddr: addr1, - DstAddr: dstAddr, + // We're lying about transport protocol here to be able to generate + // raw extension headers from the test definitions. + TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr), + HopLimit: 255, + SrcAddr: addr1, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -980,9 +1009,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { udpPayload2Length = 128 // Used to test cases where the fragment blocks are not a multiple of // the fragment block size of 8 (RFC 8200 section 4.5). - udpPayload3Length = 127 - udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize - fragmentExtHdrLen = 8 + udpPayload3Length = 127 + udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize + udpMaximumSizeMinus15 = header.UDPMaximumSize - 15 + fragmentExtHdrLen = 8 // Note, not all routing extension headers will be 8 bytes but this test // uses 8 byte routing extension headers for most sub tests. routingExtHdrLen = 8 @@ -1326,14 +1356,14 @@ func TestReceiveIPv6Fragments(t *testing.T) { dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+65520, + fragmentExtHdrLen+udpMaximumSizeMinus15, []buffer.View{ // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload4Addr1ToAddr2[:65520], + ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, ), }, @@ -1342,14 +1372,17 @@ func TestReceiveIPv6Fragments(t *testing.T) { dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520, + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, []buffer.View{ // Fragment extension header. // - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}), + // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + udpMaximumSizeMinus15 >> 8, + udpMaximumSizeMinus15 & 0xff, + 0, 0, 0, 1}), - ipv6Payload4Addr1ToAddr2[65520:], + ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, ), }, @@ -1357,6 +1390,47 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, { + name: "Two fragments with MF flag reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+udpMaximumSizeMinus15, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + udpMaximumSizeMinus15 >> 8, + (udpMaximumSizeMinus15 & 0xff) + 1, + 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { name: "Two fragments with per-fragment routing header with zero segments left", fragments: []fragmentData{ { @@ -1844,7 +1918,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1875,10 +1949,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, + // We're lying about transport protocol here so that we can generate + // raw extension headers for the tests. + TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr), + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, }) vv := hdr.View().ToVectorisedView() @@ -1912,16 +1988,19 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - nicID = 1 - fragmentExtHdrLen = 8 + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_INVALID_IPV6_FRAGMENTS" ) - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr + payload []byte } tests := []struct { @@ -1929,31 +2008,62 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 + expectICMP bool + expectICMPType header.ICMPv6Type + expectICMPCode header.ICMPv6Code + expectICMPTypeSpecific uint32 }{ { + name: "fragment size is not a multiple of 8 and the M flag is true", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 9, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 0 >> 3, + M: true, + Identification: ident, + }, + payload: []byte(data)[:9], + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6PayloadLenOffset, + }, + { name: "fragments reassembled into a payload exceeding the max IPv6 payload size", fragments: []fragmentData{ { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16, - []buffer.View{ - // Fragment extension header. - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8, - ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8, - 0, 0, 0, 1}), - // Payload length = 16 - payloadGen(16), - }, - ), + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, + M: false, + Identification: ident, + }, + payload: []byte(data)[:16], }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ }, } @@ -1964,33 +2074,39 @@ func TestInvalidIPv6Fragments(t *testing.T) { NewProtocol, }, }) - e := channel.New(0, 1500, linkAddr1) + e := channel.New(1, 1500, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + var expectICMPPayload buffer.View for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - // Serialize IPv6 fixed header. - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, - }) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) vv := hdr.View().ToVectorisedView() - vv.Append(f.data) + vv.AppendView(f.payload) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - })) + }) + + if test.expectICMP { + expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { @@ -1999,6 +2115,280 @@ func TestInvalidIPv6Fragments(t *testing.T) { if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) } + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), + checker.ICMPv6( + checker.ICMPv6Type(test.expectICMPType), + checker.ICMPv6Code(test.expectICMPCode), + checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), + checker.ICMPv6Payload([]byte(expectICMPPayload)), + ), + ) + }) + } +} + +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + Clock: clock, + }) + + e := channel.New(1, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) + } + + clock.Advance(ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), + checker.ICMPv6Payload([]byte(firstFragmentSent)), + ), + ) }) } } @@ -2035,13 +2425,10 @@ func TestWriteStats(t *testing.T) { // Install Output DROP rule. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } }, @@ -2056,17 +2443,14 @@ func TestWriteStats(t *testing.T) { // of the 3 packets. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } }, @@ -2119,7 +2503,7 @@ func TestWriteStats(t *testing.T) { test.setup(t, rt.Stack()) - nWritten, _ := writer.writePackets(&rt, pkts) + nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) @@ -2136,7 +2520,7 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) @@ -2230,8 +2614,8 @@ var fragmentationTests = []struct { wantFragments []fragmentInfo }{ { - description: "No Fragmentation", - mtu: 1280, + description: "No fragmentation", + mtu: header.IPv6MinimumMTU, gso: nil, transHdrLen: 0, payloadSize: 1000, @@ -2241,7 +2625,18 @@ var fragmentationTests = []struct { }, { description: "Fragmented", - mtu: 1280, + mtu: header.IPv6MinimumMTU, + gso: nil, + transHdrLen: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv6MinimumMTU + 1, gso: nil, transHdrLen: 0, payloadSize: 2000, @@ -2262,7 +2657,7 @@ var fragmentationTests = []struct { }, { description: "Fragmented with gso none", - mtu: 1280, + mtu: header.IPv6MinimumMTU, gso: &stack.GSO{Type: stack.GSONone}, transHdrLen: 0, payloadSize: 1400, @@ -2273,7 +2668,7 @@ var fragmentationTests = []struct { }, { description: "Fragmented with big header", - mtu: 1280, + mtu: header.IPv6MinimumMTU, gso: nil, transHdrLen: 100, payloadSize: 1200, @@ -2448,8 +2843,8 @@ func TestFragmentationErrors(t *testing.T) { wantError: tcpip.ErrAborted, }, { - description: "Error on packet with MTU smaller than transport header", - mtu: 1280, + description: "Error when MTU is smaller than transport header", + mtu: header.IPv6MinimumMTU, transHdrLen: 1500, payloadSize: 500, allowPackets: 0, @@ -2457,6 +2852,16 @@ func TestFragmentationErrors(t *testing.T) { mockError: nil, wantError: tcpip.ErrMessageTooLong, }, + { + description: "Error when MTU is smaller than IPv6 minimum MTU", + mtu: header.IPv6MinimumMTU - 1, + transHdrLen: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, + }, } for _, ft := range tests { @@ -2481,3 +2886,160 @@ func TestFragmentationErrors(t *testing.T) { }) } } + +func TestForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + ) + + ipv6Addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + } + ipv6Addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + } + remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) + remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + + tests := []struct { + name string + TTL uint8 + expectErrorICMP bool + }{ + { + name: "TTL of zero", + TTL: 0, + expectErrorICMP: true, + }, + { + name: "TTL of one", + TTL: 1, + expectErrorICMP: true, + }, + { + name: "TTL of two", + TTL: 2, + expectErrorICMP: false, + }, + { + name: "TTL of three", + TTL: 3, + expectErrorICMP: false, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectErrorICMP: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} + if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} + if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: ipv6Addr1.Subnet(), + NIC: nicID1, + }, + { + Destination: ipv6Addr2.Subnet(), + NIC: nicID2, + }, + }) + + if err := s.SetForwarding(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + } + + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) + icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv6EchoRequest) + icmp.SetCode(header.ICMPv6UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: test.TTL, + SrcAddr: remoteIPv6Addr1, + DstAddr: remoteIPv6Addr2, + }) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e1.InjectInbound(ProtocolNumber, requestPkt) + + if test.expectErrorICMP { + reply, ok := e1.Read() + if !ok { + t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + } + + checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(ipv6Addr1.Address), + checker.DstAddr(remoteIPv6Addr1), + checker.TTL(DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), + checker.ICMPv6Payload([]byte(hdr.View())), + ), + ) + + if n := e2.Drain(); n != 0 { + t.Fatalf("got e2.Drain() = %d, want = 0", n) + } + } else { + reply, ok := e2.Read() + if !ok { + t.Fatal("expected ICMP Echo Request packet through outgoing NIC") + } + + checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(remoteIPv6Addr1), + checker.DstAddr(remoteIPv6Addr2), + checker.TTL(test.TTL-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest), + checker.ICMPv6Code(header.ICMPv6UnusedCode), + checker.ICMPv6Payload(nil), + ), + ) + + if n := e1.Drain(); n != 0 { + t.Fatalf("got e1.Drain() = %d, want = 0", n) + } + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go new file mode 100644 index 000000000..6f64b8462 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld.go @@ -0,0 +1,258 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited MLD reports. + // + // Obtained from RFC 2710 Section 7.10. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// MLDOptions holds options for MLD. +type MLDOptions struct { + // Enabled indicates whether MLD will be performed. + // + // When enabled, MLD may transmit MLD report and done messages when + // joining and leaving multicast groups respectively, and handle incoming + // MLD packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*mldState)(nil) + +// mldState is the per-interface MLD state. +// +// mldState.init MUST be called to initialize the MLD state. +type mldState struct { + // The IPv6 endpoint this mldState is for. + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + return err +} + +// init sets up an mldState struct, and is required to be called before using +// a new mldState. +// +// Must only be called once for the lifetime of mld. +func (mld *mldState) init(ep *endpoint) { + mld.ep = ep + mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + // No need to perform MLD on loopback interfaces since they don't have + // neighbouring nodes. + Enabled: ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback(), + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: mld, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv6AllNodesMulticastAddress, + }) +} + +// handleMulticastListenerQuery handles a query message. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { + mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) +} + +// handleMulticastListenerReport handles a report message. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { + mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress()) +} + +// joinGroup handles joining a new group and sending and scheduling the required +// messages. +// +// If the group is already joined, returns tcpip.ErrDuplicateAddress. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) joinGroup(groupAddress tcpip.Address) { + mld.genericMulticastProtocol.JoinGroupLocked(groupAddress, !mld.ep.Enabled() /* dontInitialize */) +} + +// isInGroup returns true if the specified group has been joined locally. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { + return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Done message, if +// required. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { + return nil + } + + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of MLD, but remains +// joined locally. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) softLeaveAll() { + mld.genericMulticastProtocol.MakeAllNonMemberLocked() +} + +// initializeAll attemps to initialize the MLD state for each group that has +// been joined locally. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) initializeAll() { + mld.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) sendQueuedReports() { + mld.genericMulticastProtocol.SendQueuedReportsLocked() +} + +// writePacket assembles and sends an MLD packet. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { + sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + var mldStat *tcpip.StatCounter + switch mldType { + case header.ICMPv6MulticastListenerReport: + mldStat = sentStats.MulticastListenerReport + case header.ICMPv6MulticastListenerDone: + mldStat = sentStats.MulticastListenerDone + default: + panic(fmt.Sprintf("unrecognized mld type = %d", mldType)) + } + + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) + icmp.SetType(mldType) + header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert + // option in a Hop-by-Hop Options header. + // + // However, this would cause problems with Duplicate Address Detection with + // the first address as MLD snooping switches may not send multicast traffic + // that DAD depends on to the node performing DAD without the MLD report, as + // documented in RFC 4816: + // + // Note that when a node joins a multicast address, it typically sends a + // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810] + // for the multicast address. In the case of Duplicate Address + // Detection, the MLD report message is required in order to inform MLD- + // snooping switches, rather than routers, to forward multicast packets. + // In the above description, the delay for joining the multicast address + // thus means delaying transmission of the corresponding MLD report + // message. Since the MLD specifications do not request a random delay + // to avoid race conditions, just delaying Neighbor Solicitation would + // cause congestion by the MLD report messages. The congestion would + // then prevent the MLD-snooping switches from working correctly and, as + // a result, prevent Duplicate Address Detection from working. The + // requirement to include the delay for the MLD report in this case + // avoids this scenario. [RFC3590] also talks about some interaction + // issues between Duplicate Address Detection and MLD, and specifies + // which source address should be used for the MLD report in this case. + // + // As per RFC 3590 section 4, we should still send out MLD reports with an + // unspecified source address if we do not have an assigned link-local + // address to use as the source address to ensure DAD works as expected on + // networks with MLD snooping switches: + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + localAddress := mld.ep.getLinkLocalAddressRLocked() + if len(localAddress) == 0 { + localAddress = header.IPv6Any + } + + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + + extensionHeaders := header.IPv6ExtHdrSerializer{ + header.IPv6SerializableHopByHopExtHdr{ + &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, + }, + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(), + Data: buffer.View(icmp).ToVectorisedView(), + }) + + mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.MLDHopLimit, + }, extensionHeaders) + if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sentStats.Dropped.Increment() + return false, err + } + mldStat.Increment() + return localAddress != header.IPv6Any, nil +} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go new file mode 100644 index 000000000..e2778b656 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -0,0 +1,297 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" +) + +var ( + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) + globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) +) + +func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { + t.Helper() + + checker.IPv6WithExtHdr(t, p, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(localAddress), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(mldType, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + +func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + // The stack will join an address's solicited node multicast address when + // an address is added. An MLD report message should be sent for the + // solicited-node group. + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + + // The stack will leave an address's solicited node multicast address when + // an address is removed. An MLD done message should be sent for the + // solicited-node group. + if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a done message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) + } +} + +func TestSendQueuedMLDReports(t *testing.T) { + const ( + nicID = 1 + maxReports = 2 + ) + + tests := []struct { + name string + dadTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD Disabled", + dadTransmits: 0, + retransmitTimer: 0, + }, + { + name: "DAD Enabled", + dadTransmits: 1, + retransmitTimer: time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dadTransmits, + RetransmitTimer: test.retransmitTimer, + }, + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + Clock: clock, + }) + + // Allow space for an extra packet so we can observe packets that were + // unexpectedly sent. + e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + resolveDAD := func(addr, snmc tcpip.Address) { + clock.Advance(dadResolutionTime) + if p, ok := e.Read(); !ok { + t.Fatal("expected DAD packet") + } else { + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr), + checker.NDPNSOptions(nil), + )) + } + } + + var reportCounter uint64 + reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + var doneCounter uint64 + doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + + // Joining a group without an assigned address should send an MLD report + // with the unspecified address. + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalMulticastAddr) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a global address should not send reports for the already joined + // group since we should only send queued reports when a link-local + // addres sis assigned. + // + // Note, we will still expect to send a report for the global address's + // solicited node address from the unspecified address as per RFC 3590 + // section 4. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) + } + if dadResolutionTime != 0 { + // Reports should not be sent when the address resolves. + resolveDAD(globalAddr, globalAddrSNMC) + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + } + // Leave the group since we don't care about the global address's + // solicited node multicast group membership. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) + } + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a link-local address should send a report for its solicited node + // address and globalMulticastAddr. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + } + if dadResolutionTime != 0 { + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + resolveDAD(linkLocalAddr, linkLocalAddrSNMC) + } + + // We expect two batches of reports to be sent (1 batch when the + // link-local address is assigned, and another after the maximum + // unsolicited report interval. + for i := 0; i < 2; i++ { + // We expect reports to be sent (one for globalMulticastAddr and another + // for linkLocalAddrSNMC). + reportCounter += maxReports + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + + addrs := map[tcpip.Address]bool{ + globalMulticastAddr: false, + linkLocalAddrSNMC: false, + } + for _ = range addrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) + } + + addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() + if seen, ok := addrs[addr]; !ok { + t.Fatalf("got unexpected packet destined to %s", addr) + } else if seen { + t.Fatalf("got another packet destined to %s", addr) + } + + addrs[addr] = true + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) + + clock.Advance(ipv6.UnsolicitedReportIntervalMax) + } + } + + // Should not send any more reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 40da011f8..d515eb622 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -20,6 +20,7 @@ import ( "math/rand" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + // The IPv6 endpoint this ndpState is for. ep *endpoint @@ -471,17 +475,8 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - rtrSolicit struct { - // The timer used to send the next router solicitation message. - timer tcpip.Timer - - // Used to let the Router Solicitation timer know that it has been stopped. - // - // Must only be read from or written to while protected by the lock of - // the IPv6 endpoint this ndpState is associated with. MUST be set when the - // timer is set. - done *bool - } + // The job used to send the next router solicitation message. + rtrSolicitJob *tcpip.Job // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -507,7 +502,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer tcpip.Timer + job *tcpip.Job // Used to let the DAD timer know that it has been stopped. // @@ -648,96 +643,73 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // Consider DAD to have resolved even if no DAD messages were actually // transmitted. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } + ndp.ep.onAddressAssignedLocked(addr) return nil } - var done bool - var timer tcpip.Timer - // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the IPv6 endpoint's lock. This is effectively - // the same as starting a goroutine but we use a timer that fires immediately - // so we can reset it for the next DAD iteration. - timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { - ndp.ep.mu.Lock() - defer ndp.ep.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD timer fired after - // another goroutine already obtained the IPv6 endpoint lock and stopped - // DAD before this function obtained the NIC lock. Simply return here and - // do nothing further. - return - } + state := dadState{ + job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + state, ok := ndp.dad[addr] + if !ok { + panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID())) + } - if addressEndpoint.GetKind() != stack.PermanentTentative { - // The endpoint should still be marked as tentative since we are still - // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) - } + if addressEndpoint.GetKind() != stack.PermanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) + } - dadDone := remaining == 0 - - var err *tcpip.Error - if !dadDone { - // Use the unspecified address as the source address when performing DAD. - addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) - - // Do not hold the lock when sending packets which may be a long running - // task or may block link address resolution. We know this is safe - // because immediately after obtaining the lock again, we check if DAD - // has been stopped before doing any work with the IPv6 endpoint. Note, - // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if - // the address was removed. - ndp.ep.mu.Unlock() - err = ndp.sendDADPacket(addr, addressEndpoint) - ndp.ep.mu.Lock() - addressEndpoint.DecRef() - } + dadDone := remaining == 0 - if done { - // If we reach this point, it means that DAD was stopped after we released - // the IPv6 endpoint's read lock and before we obtained the write lock. - return - } + var err *tcpip.Error + if !dadDone { + err = ndp.sendDADPacket(addr, addressEndpoint) + } - if dadDone { - // DAD has resolved. - addressEndpoint.SetKind(stack.Permanent) - } else if err == nil { - // DAD is not done and we had no errors when sending the last NDP NS, - // schedule the next DAD timer. - remaining-- - timer.Reset(ndp.configs.RetransmitTimer) - return - } + if dadDone { + // DAD has resolved. + addressEndpoint.SetKind(stack.Permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. + remaining-- + state.job.Schedule(ndp.configs.RetransmitTimer) + return + } - // At this point we know that either DAD is done or we hit an error sending - // the last NDP NS. Either way, clean up addr's DAD state and let the - // integrator know DAD has completed. - delete(ndp.dad, addr) + // At this point we know that either DAD is done or we hit an error + // sending the last NDP NS. Either way, clean up addr's DAD state and let + // the integrator know DAD has completed. + delete(ndp.dad, addr) - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) - } + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) + } - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) - } - }) + if dadDone { + if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the + // generation of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } - ndp.dad[addr] = dadState{ - timer: timer, - done: &done, + ndp.ep.onAddressAssignedLocked(addr) + } + }), } + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the IPv6 endpoint's lock. This is effectively + // the same as starting a goroutine but we use a timer that fires immediately + // so we can reset it for the next DAD iteration. + state.job.Schedule(0) + ndp.dad[addr] = state + return nil } @@ -745,55 +717,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // addr. // // addr must be a tentative IPv6 address on ndp's IPv6 endpoint. -// -// The IPv6 endpoint that ndp belongs to MUST NOT be locked. func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - - // Route should resolve immediately since snmc is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the IPv6 endpoint is not - // locked, the NIC could have been disabled or removed by another goroutine. - if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { - return err - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) - } - - icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) - icmpData.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(addr) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(icmpData).ToVectorisedView(), + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), + Data: buffer.View(icmp).ToVectorisedView(), }) - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt, - ); err != nil { + sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, nil /* extensionHeaders */) + + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() return err } sent.NeighborSolicit.Increment() - return nil } @@ -812,18 +760,11 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { return } - if dad.timer != nil { - dad.timer.Stop() - dad.timer = nil - - *dad.done = true - dad.done = nil - } - + dad.job.Cancel() delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) } } @@ -846,7 +787,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we // only inform the dispatcher on configuration changes. We do nothing else // with the information. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { var configuration DHCPv6ConfigurationFromNDPRA switch { case ra.ManagedAddrConfFlag(): @@ -903,20 +844,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { switch opt := opt.(type) { case header.NDPRecursiveDNSServer: - if ndp.ep.protocol.ndpDisp == nil { + if ndp.ep.protocol.options.NDPDisp == nil { continue } addrs, _ := opt.Addresses() - ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) + ndp.ep.protocol.options.NDPDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) case header.NDPDNSSearchList: - if ndp.ep.protocol.ndpDisp == nil { + if ndp.ep.protocol.options.NDPDisp == nil { continue } domainNames, _ := opt.DomainNames() - ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) + ndp.ep.protocol.options.NDPDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -964,7 +905,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) } } @@ -976,7 +917,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } @@ -1006,7 +947,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } @@ -1047,7 +988,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) } } @@ -1225,7 +1166,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint { // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return nil } @@ -1272,7 +1213,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt } dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures - if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil { + if oIID := ndp.ep.protocol.options.OpaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, @@ -1676,7 +1617,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint } addressEndpoint.SetDeprecated(true) - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } } @@ -1701,7 +1642,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } @@ -1761,7 +1702,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } @@ -1859,7 +1800,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicit.timer != nil { + if ndp.rtrSolicitJob != nil { // We are already soliciting routers. return } @@ -1876,56 +1817,14 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - var done bool - ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { - ndp.ep.mu.Lock() - if done { - // If we reach this point, it means that the RS timer fired after another - // goroutine already obtained the IPv6 endpoint lock and stopped - // solicitations. Simply return here and do nothing further. - ndp.ep.mu.Unlock() - return - } - + ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false) - if addressEndpoint == nil { - // Incase this ends up creating a new temporary address, we need to hold - // onto the endpoint until a route is obtained. If we decrement the - // reference count before obtaing a route, the address's resources would - // be released and attempting to obtain a route after would fail. Once a - // route is obtainted, it is safe to decrement the reference count since - // obtaining a route increments the address's reference count. - addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) - } - ndp.ep.mu.Unlock() - - localAddr := addressEndpoint.AddressWithPrefix().Address - r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) - addressEndpoint.DecRef() - if err != nil { - return - } - defer r.Release() - - // Route should resolve immediately since - // header.IPv6AllRoutersMulticastAddress is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the IPv6 endpoint is - // not locked, the IPv6 endpoint could have been disabled or removed by - // another goroutine. - if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { - return - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) + localAddr := header.IPv6Any + if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + localAddr = addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1936,30 +1835,31 @@ func (ndp *ndpState) startSolicitingRouters() { // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. var optsSerializer header.NDPOptionsSerializer - if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + linkAddress := ndp.ep.nic.LinkAddress() + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) { optsSerializer = header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + header.NDPSourceLinkLayerAddressOption(linkAddress), } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(icmpData.NDPPayload()) + rs := header.NDPRouterSolicit(icmpData.MessageBody()) rs.Options().Serialize(optsSerializer) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt, - ); err != nil { + sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, nil /* extensionHeaders */) + + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. @@ -1969,21 +1869,12 @@ func (ndp *ndpState) startSolicitingRouters() { remaining-- } - ndp.ep.mu.Lock() - if done || remaining == 0 { - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil - } else if ndp.rtrSolicit.timer != nil { - // Note, we need to explicitly check to make sure that - // the timer field is not nil because if it was nil but - // we still reached this point, then we know the IPv6 endpoint - // was requested to stop soliciting routers so we don't - // need to send the next Router Solicitation message. - ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) + if remaining != 0 { + ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval) } - ndp.ep.mu.Unlock() }) + ndp.rtrSolicitJob.Schedule(delay) } // stopSolicitingRouters stops soliciting routers. If routers are not currently @@ -1991,22 +1882,28 @@ func (ndp *ndpState) startSolicitingRouters() { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicit.timer == nil { + if ndp.rtrSolicitJob == nil { // Nothing to do. return } - *ndp.rtrSolicit.done = true - ndp.rtrSolicit.timer.Stop() - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil + ndp.rtrSolicitJob.Cancel() + ndp.rtrSolicitJob = nil } -// initializeTempAddrState initializes state related to temporary SLAAC -// addresses. -func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) +func (ndp *ndpState) init(ep *endpoint) { + if ndp.dad != nil { + panic("attempted to initialize NDP state twice") + } + + ndp.ep = ep + ndp.configs = ep.protocol.options.NDPConfigs + ndp.dad = make(map[tcpip.Address]dadState) + ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) + ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index ac20f217e..05a0d95b2 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } - { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { @@ -73,6 +69,17 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig } t.Cleanup(ep.Close) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := llladdr.WithPrefix() + if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + addressEP.DecRef() + } + return s, ep } @@ -198,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -206,14 +213,14 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -304,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -312,23 +319,23 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) neighbors, err := s.Neighbors(nicID) if err != nil { @@ -341,7 +348,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi if diff := cmp.Diff(existing, n); diff != "" { t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) } neighborByAddr[n.Addr] = n } @@ -368,7 +375,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi } if ok { - t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) } } }) @@ -573,11 +580,18 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(nicAddr) opts := ns.Options() opts.Serialize(test.nsOpts) @@ -585,14 +599,14 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -636,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != respNSDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) } - if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -658,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(test.nsSrc) @@ -667,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, }) e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -692,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != test.naDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -763,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -771,14 +785,14 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -876,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -884,23 +898,23 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) neighbors, err := s.Neighbors(nicID) if err != nil { @@ -913,13 +927,13 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test if diff := cmp.Diff(existing, n); diff != "" { t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) } neighborByAddr[n.Addr] = n } if neigh, ok := neighborByAddr[lladdr1]; ok { - t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) } if test.isValid { @@ -954,46 +968,37 @@ func TestNDPValidation(t *testing.T) { for _, stackTyp := range stacks { t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() // Create a stack with the assigned link-local address lladdr0 // and an endpoint to lladdr1. s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r + return s, ep } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { - nextHdr := uint8(header.ICMPv6ProtocolNumber) - var extensions buffer.View + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { + var extHdrs header.IPv6ExtHdrSerializer if atomicFragment { - extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) - extensions[0] = nextHdr - nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) } + extHdrsLen := extHdrs.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen, Data: payload.ToVectorisedView(), }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + len(extensions)), - NextHeader: nextHdr, - HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(len(payload) + extHdrsLen), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: hopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + ExtensionHeaders: extHdrs, }) - if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { - t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) - } - ep.HandlePacket(r, pkt) + ep.HandlePacket(pkt) } var tllData [header.NDPLinkLayerAddressSize]byte @@ -1106,15 +1111,14 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + s, ep := setup(t) if isRouter { // Enabling forwarding makes the stack act as a router. s.SetForwarding(ProtocolNumber, true) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) @@ -1123,7 +1127,7 @@ func TestNDPValidation(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -1144,7 +1148,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { @@ -1338,19 +1342,19 @@ func TestRouterAdvertValidation(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) + copy(pkt.MessageBody(), test.ndpPayload) payloadLength := hdr.UsedLength() pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, }) - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid rxRA := stats.RouterAdvert diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go new file mode 100644 index 000000000..05d98a0a5 --- /dev/null +++ b/pkg/tcpip/network/multicast_group_test.go @@ -0,0 +1,1261 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "fmt" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + + ipv4Addr = tcpip.Address("\x0a\x00\x00\x01") + ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") + ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") + ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") + ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") + ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") + + igmpMembershipQuery = uint8(header.IGMPMembershipQuery) + igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) + igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) + igmpLeaveGroup = uint8(header.IGMPLeaveGroup) + mldQuery = uint8(header.ICMPv6MulticastListenerQuery) + mldReport = uint8(header.ICMPv6MulticastListenerReport) + mldDone = uint8(header.ICMPv6MulticastListenerDone) + + maxUnsolicitedReports = 2 +) + +var ( + // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the + // NIC will wait before sending an unsolicited report after joining a + // multicast group, in deciseconds. + unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { + const decisecond = time.Second / 10 + if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { + panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) + } + return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) + }() + + ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) +) + +// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet +// sent to the provided address with the passed fields set. +func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv6WithExtHdr(t, payload, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(ipv6Addr), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, + checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + +// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. +func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.SrcAddr(ipv4Addr), + checker.DstAddr(remoteAddress), + // TTL for an IGMP message must be 1 as per RFC 2236 section 2. + checker.TTL(1), + checker.IPv4RouterAlert(), + checker.IGMP( + checker.IGMPType(header.IGMPType(igmpType)), + checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) + s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) + return e, s, clock +} + +func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { + t.Helper() + + igmpEnabled := v4 && mgpEnabled + mldEnabled := !v4 && mgpEnabled + + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMP: ipv4.IGMPOptions{ + Enabled: igmpEnabled, + }, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: mldEnabled, + }, + }), + }, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err) + } + + return s, clock +} + +// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join +// when it is created with an IPv6 address. +// +// To not interfere with tests, checkInitialIPv6Groups will leave the added +// address's solicited node multicast group so that the tests can all assume +// the NIC has not joined any IPv6 groups. +func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { + t.Helper() + + stats := s.Stats().ICMP.V6.PacketsSent + + reportCounter++ + if got := stats.MulticastListenerReport.Value(); got != reportCounter { + t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) + } + + // Leave the group to not affect the tests. This is fine since we are not + // testing DAD or the solicited node address specifically. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) + } + leaveCounter++ + if got := stats.MulticastListenerDone.Value(); got != leaveCounter { + t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + return reportCounter, leaveCounter +} + +// createAndInjectIGMPPacket creates and injects an IGMP packet with the +// specified fields. +// +// Note, the router alert option is not included in this packet. +// +// TODO(b/162198658): set the router alert option. +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: header.IGMPTTL, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(header.IGMPType(igmpType)) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// createAndInjectMLDPacket creates and injects an MLD packet with the +// specified fields. +// +// Note, the router alert option is not included in this packet. +// +// TODO(b/162198658): set the router alert option. +func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) { + icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize + buf := buffer.NewView(header.IPv6MinimumSize + icmpSize) + + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + HopLimit: header.MLDHopLimit, + TransportProtocol: header.ICMPv6ProtocolNumber, + SrcAddr: header.IPv4Any, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + icmp := header.ICMPv6(buf[header.IPv6MinimumSize:]) + icmp.SetType(header.ICMPv6Type(mldType)) + mld := header.MLD(icmp.MessageBody()) + mld.SetMaximumResponseDelay(uint16(maxRespDelay)) + mld.SetMulticastAddress(groupAddress) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + + e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestMGPDisabled tests that the multicast group protocol is not enabled by +// default. +func TestMGPDisabled(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + rxQuery func(*channel.Endpoint) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxQuery: func(e *channel.Endpoint) { + createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxQuery: func(e *channel.Endpoint) { + createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) + + // This NIC may join multicast groups when it is enabled but since MGP is + // disabled, no reports should be sent. + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt) + } + + // Inject a general query message. This should only trigger a report to be + // sent if the MGP was enabled. + test.rxQuery(e) + if got := test.receivedQueryStat(s).Value(); got != 1 { + t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } + }) + } +} + +func TestMGPReceiveCounters(t *testing.T) { + tests := []struct { + name string + headerType uint8 + maxRespTime byte + groupAddress tcpip.Address + statCounter func(*stack.Stack) *tcpip.StatCounter + rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address) + }{ + { + name: "IGMP Membership Query", + headerType: igmpMembershipQuery, + maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, + groupAddress: header.IPv4Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMPv1 Membership Report", + headerType: igmpv1MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.V1MembershipReport + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMPv2 Membership Report", + headerType: igmpv2MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.V2MembershipReport + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMP Leave Group", + headerType: igmpLeaveGroup, + maxRespTime: 0, + groupAddress: header.IPv4AllRoutersGroup, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.LeaveGroup + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "MLD Query", + headerType: mldQuery, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxMGPkt: createAndInjectMLDPacket, + }, + { + name: "MLD Report", + headerType: mldReport, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport + }, + rxMGPkt: createAndInjectMLDPacket, + }, + { + name: "MLD Done", + headerType: mldDone, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone + }, + rxMGPkt: createAndInjectMLDPacket, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) + + test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) + if got := test.statCounter(s).Value(); got != 1 { + t.Fatalf("got %s received = %d, want = 1", test.name, got) + } + }) + } +} + +// TestMGPJoinGroup tests that when explicitly joining a multicast group, the +// stack schedules and sends correct Membership Reports. +func TestMGPJoinGroup(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } + + // Test joining a specific address explicitly and verify a Report is sent + // immediately. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + reportCounter++ + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Verify the second report is sent by the maximum unsolicited response + // interval. + p, ok := e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) + } + clock.Advance(test.maxUnsolicitedResponseDelay) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPLeaveGroup tests that when leaving a previously joined multicast +// group the stack sends a leave/done message. +func TestMGPLeaveGroup(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + validateLeave func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + reportCounter++ + if got := test.sentReportStat(s).Value(); got != reportCounter { + t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Leaving the group should trigger an leave/done message to be sent. + if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) + } + leaveCounter++ + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a leave message to be sent") + } else { + test.validateLeave(t, p) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPQueryMessages tests that a report is sent in response to query +// messages. +func TestMGPQueryMessages(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + rxQuery func(*channel.Endpoint, uint8, tcpip.Address) + validateReport func(*testing.T, channel.PacketInfo) + maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { + createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + maxRespTimeToDuration: header.DecisecondToDuration, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { + createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + maxRespTimeToDuration: func(d uint8) time.Duration { + return time.Duration(d) * time.Millisecond + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + multicastAddr tcpip.Address + expectReport bool + }{ + { + name: "Unspecified", + multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))), + expectReport: true, + }, + { + name: "Specified", + multicastAddr: test.multicastAddr, + expectReport: true, + }, + { + name: "Specified other address", + multicastAddr: func() tcpip.Address { + addrBytes := []byte(test.multicastAddr) + addrBytes[len(addrBytes)-1]++ + return tcpip.Address(addrBytes) + }(), + expectReport: false, + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + for i := 0; i < maxUnsolicitedReports; i++ { + sentReportStat := test.sentReportStat(s) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected %d-th report message to be sent", i) + } else { + test.validateReport(t, p) + } + clock.Advance(test.maxUnsolicitedResponseDelay) + } + if t.Failed() { + t.FailNow() + } + + // Should not send any more packets until a query. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + // Receive a query message which should trigger a report to be sent at + // some time before the maximum response time if the report is + // targeted at the host. + const maxRespTime = 100 + test.rxQuery(e, maxRespTime, subTest.multicastAddr) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p.Pkt) + } + + if subTest.expectReport { + clock.Advance(test.maxRespTimeToDuration(maxRespTime)) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } + }) + } +} + +// TestMGPQueryMessages tests that no further reports or leave/done messages +// are sent after receiving a report. +func TestMGPReportMessages(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + rxReport func(*channel.Endpoint) + validateReport func(*testing.T, channel.PacketInfo) + maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + rxReport: func(e *channel.Endpoint) { + createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + maxRespTimeToDuration: header.DecisecondToDuration, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + rxReport: func(e *channel.Endpoint) { + createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + maxRespTimeToDuration: func(d uint8) time.Duration { + return time.Duration(d) * time.Millisecond + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Receiving a report for a group we joined should cancel any further + // reports. + test.rxReport(e) + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("sent unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Leaving a group after getting a report should not send a leave/done + // message. + if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) + } + clock.Advance(time.Hour) + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +func TestMGPWithNICLifecycle(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddrs []tcpip.Address + finalMulticastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) + validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) + getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, + finalMulticastAddr: ipv4MulticastAddr3, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { + t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) + } + addr := header.IGMP(ipv4.Payload()).GroupAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, + finalMulticastAddr: ipv6MulticastAddr3, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, addr, mldReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + + ipv6HeaderIter := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var transport header.IPv6RawPayloadHeader + for { + h, done, err := ipv6HeaderIter.Next() + if err != nil { + t.Fatalf("ipv6HeaderIter.Next(): %s", err) + } + if done { + t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) + } + if t, ok := h.(header.IPv6RawPayloadHeader); ok { + transport = t + break + } + } + + if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { + t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) + } + icmpv6 := header.ICMPv6(transport.Buf.ToView()) + if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { + t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) + } + addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + sentReportStat := test.sentReportStat(s) + for _, a := range test.multicastAddrs { + if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected a report message to be sent for %s", a) + } else { + test.validateReport(t, p, a) + } + } + if t.Failed() { + t.FailNow() + } + + // Leave messages should be sent for the joined groups when the NIC is + // disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + sentLeaveStat := test.sentLeaveStat(s) + leaveCounter += uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i, _ := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + + test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Reports should be sent for the joined groups when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter += uint64(len(test.multicastAddrs)) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i, _ := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) report message to be sent", i) + } + + test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Joining/leaving a group while disabled should not send any messages. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + leaveCounter += uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + for i, _ := range test.multicastAddrs { + if _, ok := e.Read(); !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + } + for _, a := range test.multicastAddrs { + if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) + } + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) + } + } + if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) + } + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) + } + + // A report should only be sent for the group we last joined after + // enabling the NIC since the original groups were all left. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + clock.Advance(test.maxUnsolicitedResponseDelay) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPDisabledOnLoopback tests that the multicast group protocol is not +// performed on loopback interfaces since they have no neighbours. +func TestMGPDisabledOnLoopback(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) + + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + }) + } +} diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 7cc52985e..5c3363759 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st return n, nil } -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - - return nil -} - // Attach implements LinkEndpoint.Attach. func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 8e0ee1cd7..1c2afd554 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -148,10 +148,6 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - log.Fatal(err) - } - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) if err != nil { log.Fatal(err) diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go new file mode 100644 index 000000000..b60a5fd76 --- /dev/null +++ b/pkg/tcpip/socketops.go @@ -0,0 +1,364 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" +) + +// SocketOptionsHandler holds methods that help define endpoint specific +// behavior for socket level socket options. These must be implemented by +// endpoints to get notified when socket level options are set. +type SocketOptionsHandler interface { + // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint. + OnReuseAddressSet(v bool) + + // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint. + OnReusePortSet(v bool) + + // OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint. + OnKeepAliveSet(v bool) + + // OnDelayOptionSet is invoked when TCP_NODELAY is set for an endpoint. + // Note that v will be the inverse of TCP_NODELAY option. + OnDelayOptionSet(v bool) + + // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint. + OnCorkOptionSet(v bool) + + // LastError is invoked when SO_ERROR is read for an endpoint. + LastError() *Error +} + +// DefaultSocketOptionsHandler is an embeddable type that implements no-op +// implementations for SocketOptionsHandler methods. +type DefaultSocketOptionsHandler struct{} + +var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil) + +// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet. +func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {} + +// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet. +func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {} + +// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet. +func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {} + +// OnDelayOptionSet implements SocketOptionsHandler.OnDelayOptionSet. +func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {} + +// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet. +func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {} + +// LastError implements SocketOptionsHandler.LastError. +func (*DefaultSocketOptionsHandler) LastError() *Error { + return nil +} + +// SocketOptions contains all the variables which store values for SOL_SOCKET, +// SOL_IP, SOL_IPV6 and SOL_TCP level options. +// +// +stateify savable +type SocketOptions struct { + handler SocketOptionsHandler + + // These fields are accessed and modified using atomic operations. + + // broadcastEnabled determines whether datagram sockets are allowed to + // send packets to a broadcast address. + broadcastEnabled uint32 + + // passCredEnabled determines whether SCM_CREDENTIALS socket control + // messages are enabled. + passCredEnabled uint32 + + // noChecksumEnabled determines whether UDP checksum is disabled while + // transmitting for this socket. + noChecksumEnabled uint32 + + // reuseAddressEnabled determines whether Bind() should allow reuse of + // local address. + reuseAddressEnabled uint32 + + // reusePortEnabled determines whether to permit multiple sockets to be + // bound to an identical socket address. + reusePortEnabled uint32 + + // keepAliveEnabled determines whether TCP keepalive is enabled for this + // socket. + keepAliveEnabled uint32 + + // multicastLoopEnabled determines whether multicast packets sent over a + // non-loopback interface will be looped back. Analogous to inet->mc_loop. + multicastLoopEnabled uint32 + + // receiveTOSEnabled is used to specify if the TOS ancillary message is + // passed with incoming packets. + receiveTOSEnabled uint32 + + // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary + // message is passed with incoming packets. + receiveTClassEnabled uint32 + + // receivePacketInfoEnabled is used to specify if more inforamtion is + // provided with incoming packets such as interface index and address. + receivePacketInfoEnabled uint32 + + // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets + // being written have an IP header and the endpoint should not attach an IP + // header. + hdrIncludedEnabled uint32 + + // v6OnlyEnabled is used to determine whether an IPv6 socket is to be + // restricted to sending and receiving IPv6 packets only. + v6OnlyEnabled uint32 + + // quickAckEnabled is used to represent the value of TCP_QUICKACK option. + // It currently does not have any effect on the TCP endpoint. + quickAckEnabled uint32 + + // delayOptionEnabled is used to specify if data should be sent out immediately + // by the transport protocol. For TCP, it determines if the Nagle algorithm + // is on or off. + delayOptionEnabled uint32 + + // corkOptionEnabled is used to specify if data should be held until segments + // are full by the TCP transport protocol. + corkOptionEnabled uint32 + + // receiveOriginalDstAddress is used to specify if the original destination of + // the incoming packet should be returned as an ancillary message. + receiveOriginalDstAddress uint32 + + // mu protects the access to the below fields. + mu sync.Mutex `state:"nosave"` + + // linger determines the amount of time the socket should linger before + // close. We currently implement this option for TCP socket only. + linger LingerOption +} + +// InitHandler initializes the handler. This must be called before using the +// socket options utility. +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) { + so.handler = handler +} + +func storeAtomicBool(addr *uint32, v bool) { + var val uint32 + if v { + val = 1 + } + atomic.StoreUint32(addr, val) +} + +// GetBroadcast gets value for SO_BROADCAST option. +func (so *SocketOptions) GetBroadcast() bool { + return atomic.LoadUint32(&so.broadcastEnabled) != 0 +} + +// SetBroadcast sets value for SO_BROADCAST option. +func (so *SocketOptions) SetBroadcast(v bool) { + storeAtomicBool(&so.broadcastEnabled, v) +} + +// GetPassCred gets value for SO_PASSCRED option. +func (so *SocketOptions) GetPassCred() bool { + return atomic.LoadUint32(&so.passCredEnabled) != 0 +} + +// SetPassCred sets value for SO_PASSCRED option. +func (so *SocketOptions) SetPassCred(v bool) { + storeAtomicBool(&so.passCredEnabled, v) +} + +// GetNoChecksum gets value for SO_NO_CHECK option. +func (so *SocketOptions) GetNoChecksum() bool { + return atomic.LoadUint32(&so.noChecksumEnabled) != 0 +} + +// SetNoChecksum sets value for SO_NO_CHECK option. +func (so *SocketOptions) SetNoChecksum(v bool) { + storeAtomicBool(&so.noChecksumEnabled, v) +} + +// GetReuseAddress gets value for SO_REUSEADDR option. +func (so *SocketOptions) GetReuseAddress() bool { + return atomic.LoadUint32(&so.reuseAddressEnabled) != 0 +} + +// SetReuseAddress sets value for SO_REUSEADDR option. +func (so *SocketOptions) SetReuseAddress(v bool) { + storeAtomicBool(&so.reuseAddressEnabled, v) + so.handler.OnReuseAddressSet(v) +} + +// GetReusePort gets value for SO_REUSEPORT option. +func (so *SocketOptions) GetReusePort() bool { + return atomic.LoadUint32(&so.reusePortEnabled) != 0 +} + +// SetReusePort sets value for SO_REUSEPORT option. +func (so *SocketOptions) SetReusePort(v bool) { + storeAtomicBool(&so.reusePortEnabled, v) + so.handler.OnReusePortSet(v) +} + +// GetKeepAlive gets value for SO_KEEPALIVE option. +func (so *SocketOptions) GetKeepAlive() bool { + return atomic.LoadUint32(&so.keepAliveEnabled) != 0 +} + +// SetKeepAlive sets value for SO_KEEPALIVE option. +func (so *SocketOptions) SetKeepAlive(v bool) { + storeAtomicBool(&so.keepAliveEnabled, v) + so.handler.OnKeepAliveSet(v) +} + +// GetMulticastLoop gets value for IP_MULTICAST_LOOP option. +func (so *SocketOptions) GetMulticastLoop() bool { + return atomic.LoadUint32(&so.multicastLoopEnabled) != 0 +} + +// SetMulticastLoop sets value for IP_MULTICAST_LOOP option. +func (so *SocketOptions) SetMulticastLoop(v bool) { + storeAtomicBool(&so.multicastLoopEnabled, v) +} + +// GetReceiveTOS gets value for IP_RECVTOS option. +func (so *SocketOptions) GetReceiveTOS() bool { + return atomic.LoadUint32(&so.receiveTOSEnabled) != 0 +} + +// SetReceiveTOS sets value for IP_RECVTOS option. +func (so *SocketOptions) SetReceiveTOS(v bool) { + storeAtomicBool(&so.receiveTOSEnabled, v) +} + +// GetReceiveTClass gets value for IPV6_RECVTCLASS option. +func (so *SocketOptions) GetReceiveTClass() bool { + return atomic.LoadUint32(&so.receiveTClassEnabled) != 0 +} + +// SetReceiveTClass sets value for IPV6_RECVTCLASS option. +func (so *SocketOptions) SetReceiveTClass(v bool) { + storeAtomicBool(&so.receiveTClassEnabled, v) +} + +// GetReceivePacketInfo gets value for IP_PKTINFO option. +func (so *SocketOptions) GetReceivePacketInfo() bool { + return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0 +} + +// SetReceivePacketInfo sets value for IP_PKTINFO option. +func (so *SocketOptions) SetReceivePacketInfo(v bool) { + storeAtomicBool(&so.receivePacketInfoEnabled, v) +} + +// GetHeaderIncluded gets value for IP_HDRINCL option. +func (so *SocketOptions) GetHeaderIncluded() bool { + return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0 +} + +// SetHeaderIncluded sets value for IP_HDRINCL option. +func (so *SocketOptions) SetHeaderIncluded(v bool) { + storeAtomicBool(&so.hdrIncludedEnabled, v) +} + +// GetV6Only gets value for IPV6_V6ONLY option. +func (so *SocketOptions) GetV6Only() bool { + return atomic.LoadUint32(&so.v6OnlyEnabled) != 0 +} + +// SetV6Only sets value for IPV6_V6ONLY option. +// +// Preconditions: the backing TCP or UDP endpoint must be in initial state. +func (so *SocketOptions) SetV6Only(v bool) { + storeAtomicBool(&so.v6OnlyEnabled, v) +} + +// GetQuickAck gets value for TCP_QUICKACK option. +func (so *SocketOptions) GetQuickAck() bool { + return atomic.LoadUint32(&so.quickAckEnabled) != 0 +} + +// SetQuickAck sets value for TCP_QUICKACK option. +func (so *SocketOptions) SetQuickAck(v bool) { + storeAtomicBool(&so.quickAckEnabled, v) +} + +// GetDelayOption gets inverted value for TCP_NODELAY option. +func (so *SocketOptions) GetDelayOption() bool { + return atomic.LoadUint32(&so.delayOptionEnabled) != 0 +} + +// SetDelayOption sets inverted value for TCP_NODELAY option. +func (so *SocketOptions) SetDelayOption(v bool) { + storeAtomicBool(&so.delayOptionEnabled, v) + so.handler.OnDelayOptionSet(v) +} + +// GetCorkOption gets value for TCP_CORK option. +func (so *SocketOptions) GetCorkOption() bool { + return atomic.LoadUint32(&so.corkOptionEnabled) != 0 +} + +// SetCorkOption sets value for TCP_CORK option. +func (so *SocketOptions) SetCorkOption(v bool) { + storeAtomicBool(&so.corkOptionEnabled, v) + so.handler.OnCorkOptionSet(v) +} + +// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) GetReceiveOriginalDstAddress() bool { + return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0 +} + +// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { + storeAtomicBool(&so.receiveOriginalDstAddress, v) +} + +// GetLastError gets value for SO_ERROR option. +func (so *SocketOptions) GetLastError() *Error { + return so.handler.LastError() +} + +// GetOutOfBandInline gets value for SO_OOBINLINE option. +func (*SocketOptions) GetOutOfBandInline() bool { + return true +} + +// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not +// support disabling this option. +func (*SocketOptions) SetOutOfBandInline(bool) {} + +// GetLinger gets value for SO_LINGER option. +func (so *SocketOptions) GetLinger() LingerOption { + so.mu.Lock() + linger := so.linger + so.mu.Unlock() + return linger +} + +// SetLinger sets value for SO_LINGER option. +func (so *SocketOptions) SetLinger(linger LingerOption) { + so.mu.Lock() + so.linger = linger + so.mu.Unlock() +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index d09ebe7fa..9cc6074da 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "most_shards") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -112,7 +112,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], - shard_count = 20, + shard_count = most_shards, deps = [ ":stack", "//pkg/rand", @@ -120,6 +120,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", @@ -131,7 +132,6 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 4d3acab96..cd423bf71 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) var _ AddressableEndpoint = (*AddressableEndpointState)(nil) // AddressableEndpointState is an implementation of an AddressableEndpoint. @@ -37,10 +36,6 @@ type AddressableEndpointState struct { endpoints map[tcpip.Address]*addressState primary []*addressState - - // groups holds the mapping between group addresses and the number of times - // they have been joined. - groups map[tcpip.Address]uint32 } } @@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { a.mu.Lock() defer a.mu.Unlock() a.mu.endpoints = make(map[tcpip.Address]*addressState) - a.mu.groups = make(map[tcpip.Address]uint32) -} - -// ReadOnlyAddressableEndpointState provides read-only access to an -// AddressableEndpointState. -type ReadOnlyAddressableEndpointState struct { - inner *AddressableEndpointState } -// AddrOrMatching returns an endpoint for the passed address that is consisdered -// bound to the wrapped AddressableEndpointState. +// GetAddress returns the AddressEndpoint for the passed address. // -// If addr is an exact match with an existing address, that address is returned. -// Otherwise, f is called with each address and the address that f returns true -// for is returned. +// GetAddress does not increment the address's reference count or check if the +// address is considered bound to the endpoint. // -// Returns nil of no address matches. -func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - - if ep, ok := m.inner.mu.endpoints[addr]; ok { - if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { - return ep - } - } - - for _, ep := range m.inner.mu.endpoints { - if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { - return ep - } - } - - return nil -} - -// Lookup returns the AddressEndpoint for the passed address. -// -// Returns nil if the passed address is not associated with the -// AddressableEndpointState. -func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Returns nil if the passed address is not associated with the endpoint. +func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() - ep, ok := m.inner.mu.endpoints[addr] + ep, ok := a.mu.endpoints[addr] if !ok { return nil } return ep } -// ForEach calls f for each address pair. +// ForEachEndpoint calls f for each address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() - for _, ep := range m.inner.mu.endpoints { + for _, ep := range a.mu.endpoints { if !f(ep) { return } @@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) // ForEachPrimaryEndpoint calls f for each primary address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - for _, ep := range m.inner.mu.primary { - f(ep) - } -} +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() -// ReadOnly returns a readonly reference to a. -func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { - return ReadOnlyAddressableEndpointState{inner: a} + for _, ep := range a.mu.primary { + if !f(ep) { + return + } + } } func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { @@ -272,6 +233,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState = &addressState{ addressableEndpointState: a, addr: addr, + // Cache the subnet in addrState to avoid calls to addr.Subnet() as that + // results in allocations on every call. + subnet: addr.Subnet(), } a.mu.endpoints[addr.Address] = addrState addrState.mu.Lock() @@ -332,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { a.mu.Lock() defer a.mu.Unlock() - - if _, ok := a.mu.groups[addr]; ok { - panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) - } - return a.removePermanentAddressLocked(addr) } @@ -361,6 +320,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * return tcpip.ErrInvalidEndpointState } + a.mu.Lock() + defer a.mu.Unlock() return a.removePermanentEndpointLocked(addrState) } @@ -466,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad return deprecatedEndpoint } -// AcquireAssignedAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { +// AcquireAssignedAddressOrMatching returns an address endpoint that is +// considered assigned to the addressable endpoint. +// +// If the address is an exact match with an existing address, that address is +// returned. Otherwise, if f is provided, f is called with each address and +// the address that f returns true for is returned. +// +// If there is no matching address, a temporary address will be returned if +// allowTemp is true. +// +// Regardless how the address was obtained, it will be acquired before it is +// returned. +func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { a.mu.Lock() defer a.mu.Unlock() @@ -483,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return addrState } + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } + } + } + if !allowTemp { return nil } @@ -515,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return ep } +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB) +} + // AcquireOutgoingPrimaryAddress implements AddressableEndpoint. func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { a.mu.RLock() @@ -583,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi return addrs } -// JoinGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */) - if err != nil { - return false, err - } - // We have no need for the address endpoint. - a.decAddressRefLocked(ep) - } - - a.mu.groups[group] = joins + 1 - return !ok, nil -} - -// LeaveGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - return false, tcpip.ErrBadLocalAddress - } - - if joins == 1 { - a.removeGroupAddressLocked(group) - delete(a.mu.groups, group) - return true, nil - } - - a.mu.groups[group] = joins - 1 - return false, nil -} - -// IsInGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { - a.mu.RLock() - defer a.mu.RUnlock() - _, ok := a.mu.groups[group] - return ok -} - -func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) { - if err := a.removePermanentAddressLocked(group); err != nil { - // removePermanentEndpointLocked would only return an error if group is - // not bound to the addressable endpoint, but we know it MUST be assigned - // since we have group in our map of groups. - panic(fmt.Sprintf("error removing group address = %s: %s", group, err)) - } -} - // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() defer a.mu.Unlock() - for group := range a.mu.groups { - a.removeGroupAddressLocked(group) - } - a.mu.groups = make(map[tcpip.Address]uint32) - for _, ep := range a.mu.endpoints { // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is // not a permanent address. @@ -664,7 +588,7 @@ var _ AddressEndpoint = (*addressState)(nil) type addressState struct { addressableEndpointState *AddressableEndpointState addr tcpip.AddressWithPrefix - + subnet tcpip.Subnet // Lock ordering (from outer to inner lock ordering): // // AddressableEndpointState.mu @@ -684,6 +608,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { return a.addr } +// Subnet implements AddressEndpoint. +func (a *addressState) Subnet() tcpip.Subnet { + return a.subnet +} + // GetKind implements AddressEndpoint. func (a *addressState) GetKind() AddressKind { a.mu.RLock() diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 26787d0a3..140f146f6 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { ep.DecRef() } - group := tcpip.Address("\x02") - if added, err := s.JoinGroup(group); err != nil { - t.Fatalf("s.JoinGroup(%s): %s", group, err) - } else if !added { - t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) - } - if !s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) - } - s.Cleanup() - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } - } - if s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) } } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 0cd1da11f..9a17efcba 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.Addr - replyTID.srcPort = rt.Port + replyTID.srcAddr = address + replyTID.srcPort = port var manip manipType switch hook { case Prerouting: @@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + length := uint16(len(tcpHeader) + pkt.Data.Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) - } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumVV(pkt.Data, xsum) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index cf042309e..5ec9b3411 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -73,9 +73,31 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) +func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { + netHdr := pkt.NetworkHeader().View() + _, dst := f.proto.ParseAddresses(netHdr) + + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) + return + } + + r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) + if err != nil { + return + } + defer r.Release() + + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + pkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: vv.ToView().ToVectorisedView(), + }) + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + _ = r.WriteHeaderIncludedPacket(pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { + // The network header should not already be populated. + if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { + return tcpip.ErrMalformedHeader + } + + return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() { // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { + stack *Stack + addrCache *linkAddrCache neigh *neighborCache addrResolveDelay time.Duration @@ -178,7 +207,7 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) @@ -280,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, + RemoteLinkAddress: r.RemoteLinkAddress(), LocalLinkAddress: r.LocalLinkAddress, Pkt: pkt, } @@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := fwdTestPacketInfo{ - Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), - } - - select { - case e.C <- p: - default: - } - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} @@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { + proto.stack = s + return proto + }}, UseNeighborCache: useNeighborCache, }) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 8d6d9a7f1..2d8c883cd 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -22,30 +22,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -// tableID is an index into IPTables.tables. -type tableID int +// TableID identifies a specific table. +type TableID int +// Each value identifies a specific table. const ( - natID tableID = iota - mangleID - filterID - numTables + NATID TableID = iota + MangleID + FilterID + NumTables ) -// Table names. -const ( - NATTable = "nat" - MangleTable = "mangle" - FilterTable = "filter" -) - -// nameToID is immutable. -var nameToID = map[string]tableID{ - NATTable: natID, - MangleTable: mangleID, - FilterTable: filterID, -} - // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 @@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - v4Tables: [numTables]Table{ - natID: Table{ + v4Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -81,7 +68,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -99,7 +86,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -122,8 +109,8 @@ func DefaultTables() *IPTables { }, }, }, - v6Tables: [numTables]Table{ - natID: Table{ + v6Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -146,7 +133,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -164,7 +151,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -187,10 +174,10 @@ func DefaultTables() *IPTables { }, }, }, - priorities: [NumHooks][]tableID{ - Prerouting: []tableID{mangleID, natID}, - Input: []tableID{natID, filterID}, - Output: []tableID{mangleID, natID, filterID}, + priorities: [NumHooks][]TableID{ + Prerouting: []TableID{MangleID, NATID}, + Input: []TableID{NATID, FilterID}, + Output: []TableID{MangleID, NATID, FilterID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -229,26 +216,20 @@ func EmptyNATTable() Table { } } -// GetTable returns a table by name. -func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { - id, ok := nameToID[name] - if !ok { - return Table{}, false - } +// GetTable returns a table with the given id and IP version. It panics when an +// invalid id is provided. +func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { it.mu.RLock() defer it.mu.RUnlock() if ipv6 { - return it.v6Tables[id], true + return it.v6Tables[id] } - return it.v4Tables[id], true + return it.v4Tables[id] } -// ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { - id, ok := nameToID[name] - if !ok { - return tcpip.ErrInvalidOptionValue - } +// ReplaceTable replaces or inserts table by name. It panics when an invalid id +// is provided. +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer for _, tableID := range priorities { // If handlePacket already NATed the packet, we don't need to // check the NAT table. - if tableID == natID && pkt.NatDone { + if tableID == NATID && pkt.NatDone { continue } var table Table diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 538c4625d..d63e9757c 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -26,13 +28,6 @@ type AcceptTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (at *AcceptTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: at.NetworkProtocol, - } -} - // Action implements Target.Action. func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 @@ -44,22 +39,11 @@ type DropTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (dt *DropTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: dt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } -// ErrorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const ErrorTargetName = "ERROR" - // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. type ErrorTarget struct { @@ -67,14 +51,6 @@ type ErrorTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (et *ErrorTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: et.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") @@ -90,14 +66,6 @@ type UserChainTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (uc *UserChainTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: uc.NetworkProtocol, - } -} - // Action implements Target.Action. func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") @@ -110,50 +78,39 @@ type ReturnTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *ReturnTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } -// RedirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const RedirectTargetName = "REDIRECT" - -// RedirectTarget redirects the packet by modifying the destination port/IP. +// RedirectTarget redirects the packet to this machine by modifying the +// destination port/IP. Outgoing packets are redirected to the loopback device, +// and incoming packets are redirected to the incoming interface (rather than +// forwarded). +// // TODO(gvisor.dev/issue/170): Other flags need to be added after we support // them. type RedirectTarget struct { - // Addr indicates address used to redirect. - Addr tcpip.Address - - // Port indicates port used to redirect. + // Port indicates port used to redirect. It is immutable. Port uint16 - // NetworkProtocol is the network protocol the target is used with. + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *RedirectTarget) ID() TargetID { - return TargetID{ - Name: RedirectTargetName, - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for PREROUTING and calls pkt.Clone(), neither +// implementation only works for Prerouting and calls pkt.Clone(), neither // of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. + if rt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + rt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + address = tcpip.Address([]byte{127, 0, 0, 1}) } else { - rt.Addr = header.IPv6Loopback + address = header.IPv6Loopback } case Prerouting: - rt.Addr = address + // No-op, as address is already set correctly. default: panic("redirect target is supported only on output and prerouting hooks") } @@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) + netHeader := pkt.Network() + netHeader.SetDestinationAddress(address) // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(protocol, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumVV(pkt.Data, xsum) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - pkt.Network().SetDestinationAddress(rt.Addr) - // After modification, IPv4 packets need a valid checksum. if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { netHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 7b3f3e88b..4b86c1be9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -37,7 +37,6 @@ import ( // ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> type Hook uint -// These values correspond to values in include/uapi/linux/netfilter.h. const ( // Prerouting happens before a packet is routed to applications or to // be forwarded. @@ -86,8 +85,8 @@ type IPTables struct { mu sync.RWMutex // v4Tables and v6tables map tableIDs to tables. They hold builtin // tables only, not user tables. mu must be locked for accessing. - v4Tables [numTables]Table - v6Tables [numTables]Table + v4Tables [NumTables]Table + v6Tables [NumTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. @@ -96,7 +95,7 @@ type IPTables struct { // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. - priorities [NumHooks][]tableID + priorities [NumHooks][]TableID connections ConnTrack @@ -104,6 +103,24 @@ type IPTables struct { reaperDone chan struct{} } +// VisitTargets traverses all the targets of all tables and replaces each with +// transform(target). +func (it *IPTables) VisitTargets(transform func(Target) Target) { + it.mu.Lock() + defer it.mu.Unlock() + + for tid := range it.v4Tables { + for i, rule := range it.v4Tables[tid].Rules { + it.v4Tables[tid].Rules[i].Target = transform(rule.Target) + } + } + for tid := range it.v6Tables { + for i, rule := range it.v6Tables[tid].Rules { + it.v6Tables[tid].Rules[i].Target = transform(rule.Target) + } + } +} + // A Table defines a set of chains and hooks into the network stack. // // It is a list of Rules, entry points (BuiltinChains), and error handlers @@ -169,7 +186,6 @@ type IPHeaderFilter struct { // CheckProtocol determines whether the Protocol field should be // checked during matching. - // TODO(gvisor.dev/issue/3549): Check this field during matching. CheckProtocol bool // Dst matches the destination IP address. @@ -309,23 +325,8 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } -// A TargetID uniquely identifies a target. -type TargetID struct { - // Name is the target name as stored in the xt_entry_target struct. - Name string - - // NetworkProtocol is the protocol to which the target applies. - NetworkProtocol tcpip.NetworkProtocolNumber - - // Revision is the version of the target. - Revision uint8 -} - // A Target is the interface for taking an action for a packet. type Target interface { - // ID uniquely identifies the Target. - ID() TargetID - // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 6f73a0ce4..c9b13cd0e 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { return addr, nil, nil @@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.linkAddr, entry.done, tcpip.ErrWouldBlock @@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 33806340e..d2e37f38d 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -49,8 +49,8 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { - time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 73a01c2dd..03d7b4e0d 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) { } // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 { t.Fatalf("got NeighborSolicit = %d, want = 0", got) } } @@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) { if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) { } else if r.LocalAddress != addr1 { t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) { } // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } @@ -533,8 +540,8 @@ func TestDADResolve(t *testing.T) { // Make sure the right remote link address is used. snmc := header.SolicitedNodeAddr(addr1) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { + t.Errorf("got remote link address = %s, want = %s", got, want) } // Check NDP NS packet. @@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) } @@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(tgt) @@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: tgt, + DstAddr: header.IPv6AllNodesMulticastAddress, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, @@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate an address conflict. test.rxPkt(e, addr1) - stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) + stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived) if got := stat.Value(); got != 1 { t.Fatalf("got stat = %d, want = 1", got) } @@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) { } // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 { t.Errorf("got NeighborSolicit = %d, want <= 1", got) } }) @@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - raPayload := pkt.NDPPayload() + raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. binary.BigEndian.PutUint16(raPayload[2:], rl) @@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, }) return stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { NDPConfigs: ipv6.NDPConfigurations{ AutoGenTempGlobalAddresses: true, }, - NDPDisp: &ndpDisp, - AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, })}, }) @@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(addr); err != nil { t.Fatalf("ep.Connect(%+v): %s", addr, err) } @@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Bind(addr); err != nil { t.Fatalf("ep.Bind(%+v): %s", addr, err) } @@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) @@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + AutoGenLinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName @@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, + AutoGenLinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) { } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: true, + AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + // Make sure the right remote link address is used. + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want { + t.Errorf("got remote link address = %s, want = %s", got, want) + } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet") - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) - remaining-- + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) } + } - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) - waitForPkt(defaultAsyncPositiveEventTimeout) - } else { - waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) + waitForPkt(test.effectiveRtrSolicitInt) } + } - // Make sure the counter got properly - // incremented. - if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } - }) + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } + + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) + } } func TestStopStartSolicitingRouters(t *testing.T) { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 4df288798..317f6871d 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "time" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -25,9 +24,16 @@ import ( const neighborCacheSize = 512 // max entries per interface +// NeighborStats holds metrics for the neighbor table. +type NeighborStats struct { + // FailedEntryLookups counts the number of lookups performed on an entry in + // Failed state. + FailedEntryLookups *tcpip.StatCounter +} + // neighborCache maps IP addresses to link addresses. It uses the Least // Recently Used (LRU) eviction strategy to implement a bounded cache for -// dynmically acquired entries. It contains the state machine and configuration +// dynamically acquired entries. It contains the state machine and configuration // for running Neighbor Unreachability Detection (NUD). // // There are two types of entries in the neighbor cache: @@ -68,7 +74,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -84,7 +90,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -111,28 +117,31 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // provided, it will be notified when address resolution is complete (success // or not). // +// If specified, the local address must be an address local to the interface the +// neighbor cache belongs to. The local address is the source address of a +// packet prompting NUD/link address resolution. +// // If address resolution is required, ErrNoLinkAddress and a notification // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), + Addr: remoteAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAtNanos: 0, } return e, nil, nil } - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() switch s := entry.neigh.State; s { case Stale: - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) fallthrough case Reachable, Static, Delay, Probe: // As per RFC 4861 section 7.3.3: @@ -152,7 +161,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock case Failed: return entry.neigh, nil, tcpip.ErrNoLinkAddress @@ -173,14 +182,15 @@ func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { // entries returns all entries in the neighbor cache. func (n *neighborCache) entries() []NeighborEntry { - entries := make([]NeighborEntry, 0, len(n.cache)) n.mu.RLock() + defer n.mu.RUnlock() + + entries := make([]NeighborEntry, 0, len(n.cache)) for _, entry := range n.cache { entry.mu.RLock() entries = append(entries, entry.neigh) entry.mu.RUnlock() } - n.mu.RUnlock() return entries } @@ -207,7 +217,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd } else { // Static entry found with the same address but different link address. entry.neigh.LinkAddr = linkAddr - entry.dispatchChangeEventLocked(entry.neigh.State) + entry.dispatchChangeEventLocked() entry.mu.Unlock() return } @@ -220,11 +230,12 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) - n.cache[addr] = entry + n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } // removeEntryLocked removes the specified entry from the neighbor cache. +// +// Prerequisite: n.mu and entry.mu MUST be locked. func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { if entry.neigh.State != Static { n.dynamic.lru.Remove(entry) @@ -292,8 +303,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. -func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) +func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index fcd54ed83..732a299f7 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -61,39 +61,39 @@ const ( ) // entryDiffOpts returns the options passed to cmp.Diff to compare neighbor -// entries. The UpdatedAt field is ignored due to a lack of a deterministic -// method to predict the time that an event will be dispatched. +// entries. The UpdatedAtNanos field is ignored due to a lack of a +// deterministic method to predict the time that an event will be dispatched. func entryDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), } } // entryDiffOptsWithSort is like entryDiffOpts but also includes an option to // sort slices of entries for cases where ordering must be ignored. func entryDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + })) } func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - return &neighborCache{ + neigh := &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, nudDisp: nudDisp, }, - id: 1, + id: 1, + stats: makeNICStats(), }, state: NewNUDState(config, rng), cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } + neigh.nic.neigh = neigh + return neigh } // testEntryStore contains a set of IP to NeighborEntry mappings. @@ -128,9 +128,8 @@ func newTestEntryStore() *testEntryStore { linkAddr := toLinkAddress(i) store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LocalAddr: testEntryLocalAddr, - LinkAddr: linkAddr, + Addr: addr, + LinkAddr: linkAddr, } } return store @@ -195,10 +194,10 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(addr) + r.fakeRequest(targetAddr) }) // Execute post address resolution action, if available. @@ -294,9 +293,8 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -305,15 +303,19 @@ func TestNeighborCacheEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -324,8 +326,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -354,9 +356,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -365,15 +367,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -391,9 +397,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -404,8 +412,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -452,8 +460,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -470,23 +478,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, testEntryEventInfo{ EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }) c.nudDisp.mu.Lock() @@ -508,10 +522,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -564,24 +577,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -600,9 +616,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -640,9 +658,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -682,9 +702,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -703,9 +725,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -740,9 +764,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -760,9 +786,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -800,24 +828,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -836,16 +867,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -861,10 +896,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, }, }, } @@ -896,12 +930,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } clock.Advance(typicalLatency) @@ -913,7 +947,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { id, ok := s.Fetch(false /* block */) if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) } if id != wakerID { t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) @@ -923,15 +957,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -964,12 +1002,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } // Remove the waker before the neighbor cache has the opportunity to send a @@ -991,15 +1029,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1028,10 +1070,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) @@ -1041,9 +1082,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -1058,10 +1101,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, }, }, } @@ -1089,9 +1131,8 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1099,15 +1140,19 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1126,9 +1171,11 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1149,16 +1196,20 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestRemoved, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1185,24 +1236,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1220,9 +1274,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1274,29 +1330,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1312,9 +1372,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { for i := neighborCacheSize; i < store.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) - if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1322,15 +1381,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1342,22 +1401,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1374,10 +1439,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { - Addr: frequentlyUsedEntry.Addr, - LocalAddr: frequentlyUsedEntry.LocalAddr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, + Addr: frequentlyUsedEntry.Addr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, }, } @@ -1387,10 +1451,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1430,9 +1493,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } @@ -1456,10 +1518,9 @@ func TestNeighborCacheConcurrent(t *testing.T) { t.Errorf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1488,37 +1549,36 @@ func TestNeighborCacheReplace(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) } if t.Failed() { t.FailNow() } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1542,37 +1602,35 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Delay, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1601,35 +1659,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) - got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } // Verify that address resolution for an unknown address returns ErrNoLinkAddress before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } maxAttempts := neigh.config().MaxUnicastProbes @@ -1659,13 +1716,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } } @@ -1683,18 +1740,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) { delay: typicalLatency, } - got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ - Addr: testEntryBroadcastAddr, - LocalAddr: testEntryLocalAddr, - LinkAddr: testEntryBroadcastLinkAddr, - State: Static, + Addr: testEntryBroadcastAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1719,9 +1775,9 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh != nil { <-doneCh diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index be61a21af..32399b4f5 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -24,13 +24,18 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +const ( + // immediateDuration is a duration of zero for scheduling work that needs to + // be done immediately but asynchronously to avoid deadlock. + immediateDuration time.Duration = 0 +) + // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LocalAddr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAtNanos int64 } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -106,35 +111,35 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { return &neighborEntry{ nic: nic, linkRes: linkRes, nudState: nudState, neigh: NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - State: Unknown, + Addr: remoteAddr, + State: Unknown, }, } } -// newStaticNeighborEntry creates a neighbor cache entry starting at the Static -// state. The entry can only transition out of Static by directly calling -// `setStateLocked`. +// newStaticNeighborEntry creates a neighbor cache entry starting at the +// Static state. The entry can only transition out of Static by directly +// calling `setStateLocked`. func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + entry := NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + } if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) } return &neighborEntry{ nic: nic, nudState: state, - neigh: NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - }, + neigh: entry, } } @@ -165,17 +170,17 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. -func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborAdded(e.nic.id, e.neigh) } } // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. -func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborChanged(e.nic.id, e.neigh) } } @@ -183,7 +188,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { // has been removed. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } @@ -201,68 +206,24 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAt = time.Now() + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { case Incomplete: - var retryCounter uint32 - var sendMulticastProbe func() - - sendMulticastProbe = func() { - if retryCounter == config.MaxMulticastProbes { - // "If no Neighbor Advertisement is received after - // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. - // The sender MUST return ICMP destination unreachable indications with - // code 3 (Address Unreachable) for each packet queued awaiting address - // resolution." - RFC 4861 section 7.2.2 - // - // There is no need to send an ICMP destination unreachable indication - // since the failure to resolve the address is expected to only occur - // on this node. Thus, redirecting traffic is currently not supported. - // - // "If the error occurs on a node other than the node originating the - // packet, an ICMP error message is generated. If the error occurs on - // the originating node, an implementation is not required to actually - // create and send an ICMP error packet to the source, as long as the - // upper-layer sender is notified through an appropriate mechanism - // (e.g. return value from a procedure call). Note, however, that an - // implementation may find it convenient in some cases to return errors - // to the sender by taking the offending packet, generating an ICMP - // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { - // There is no need to log the error here; the NUD implementation may - // assume a working link. A valid link should be the responsibility of - // the NIC/stack.LinkEndpoint. - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(config.RetransmitTimer) - } - - sendMulticastProbe() + panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Probe) e.setStateLocked(Probe) + e.dispatchChangeEventLocked() }) e.job.Schedule(config.DelayFirstProbeTime) @@ -277,28 +238,27 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - if retryCounter == config.MaxUnicastProbes { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } - sendUnicastProbe() + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(immediateDuration) case Failed: e.notifyWakersLocked() - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() { e.nic.neigh.removeEntryLocked(e) }) e.job.Schedule(config.UnreachableTime) @@ -315,19 +275,82 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. -func (e *neighborEntry) handlePacketQueuedLocked() { +func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Unknown: - e.dispatchAddEventLocked(Incomplete) - e.setStateLocked(Incomplete) + e.neigh.State = Incomplete + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + + e.dispatchAddEventLocked() + + config := e.nudState.Config() + + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + // As per RFC 4861 section 7.2.2: + // + // If the source address of the packet prompting the solicitation is the + // same as one of the addresses assigned to the outgoing interface, that + // address SHOULD be placed in the IP Source Address of the outgoing + // solicitation. + // + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(immediateDuration) case Stale: - e.dispatchChangeEventLocked(Delay) e.setStateLocked(Delay) + e.dispatchChangeEventLocked() - case Incomplete, Reachable, Delay, Probe, Static, Failed: + case Incomplete, Reachable, Delay, Probe, Static: // Do nothing - + case Failed: + e.nic.stats.Neighbor.FailedEntryLookups.Increment() default: panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } @@ -345,21 +368,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { switch e.neigh.State { case Unknown, Incomplete, Failed: e.neigh.LinkAddr = remoteLinkAddr - e.dispatchAddEventLocked(Stale) e.setStateLocked(Stale) e.notifyWakersLocked() + e.dispatchAddEventLocked() case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } case Stale: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) + e.dispatchChangeEventLocked() } case Static: @@ -393,12 +416,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.neigh.LinkAddr = linkAddr if flags.Solicited { - e.dispatchChangeEventLocked(Reachable) e.setStateLocked(Reachable) } else { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) } + e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter e.notifyWakersLocked() @@ -411,8 +433,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } break } @@ -421,23 +443,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if !flags.Solicited { if e.neigh.State != Stale { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } else { // Notify the LinkAddr change, even though NUD state hasn't changed. - e.dispatchChangeEventLocked(e.neigh.State) + e.dispatchChangeEventLocked() } break } } if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - } + wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) e.notifyWakersLocked() + if !wasReachable { + e.dispatchChangeEventLocked() + } } if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { @@ -475,11 +498,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - // Set state to Reachable again to refresh timers. - } + wasReachable := e.neigh.State == Reachable + // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) + if !wasReachable { + e.dispatchChangeEventLocked() + } case Unknown, Incomplete, Failed, Static: // Do nothing @@ -488,3 +512,23 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() { panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } } + +// doubleLock combines two locks into one while maintaining lock ordering. +// +// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed +// neighbor is allowed. +type doubleLock struct { + first, second sync.Locker +} + +// Lock locks both locks in order: first then second. +func (l *doubleLock) Lock() { + l.first.Lock() + l.second.Lock() +} + +// Unlock unlocks both locks in reverse order: second then first. +func (l *doubleLock) Unlock() { + l.second.Unlock() + l.first.Unlock() +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 3ee2a3b31..c497d3932 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -47,24 +47,27 @@ const ( entryTestNetDefaultMTU = 65536 ) +// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { + clock.Advance(immediateDuration) +} + // eventDiffOpts are the options passed to cmp.Diff to compare entry events. -// The UpdatedAt field is ignored due to a lack of a deterministic method to -// predict the time that an event will be dispatched. +// The UpdatedAtNanos field is ignored due to a lack of a deterministic method +// to predict the time that an event will be dispatched. func eventDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), } } // eventDiffOptsWithSort is like eventDiffOpts but also includes an option to // sort slices of events for cases where ordering must be ignored. func eventDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 + })) } // The following unit tests exercise every state transition and verify its @@ -86,7 +89,7 @@ func eventDiffOptsWithSort() []cmp.Option { // | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Stale | Stale | Override confirmation | Update LinkAddr | Changed | // | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | -// | Stale | Delay | Packet sent | | Changed | +// | Stale | Delay | Packet queued | | Changed | // | Delay | Reachable | Upper-layer confirmation | | Changed | // | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | // | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | @@ -98,6 +101,7 @@ func eventDiffOptsWithSort() []cmp.Option { // | Probe | Stale | Probe or confirmation w/ different address | | Changed | // | Probe | Probe | Retransmit timer expired | Send probe | Changed | // | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Failed | Failed | Packet queued | | | // | Failed | | Unreachability timer expired | Delete entry | | type testEntryEventType uint8 @@ -125,14 +129,11 @@ func (t testEntryEventType) String() string { type testEntryEventInfo struct { EventType testEntryEventType NICID tcpip.NICID - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Entry NeighborEntry } func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) } // testNUDDispatcher implements NUDDispatcher to validate the dispatching of @@ -150,36 +151,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { d.events = append(d.events, e) } -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestAdded, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestChanged, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestRemoved, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } @@ -202,9 +194,9 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { p := entryTestProbeInfo{ - RemoteAddress: addr, + RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, LocalAddress: localAddr, } @@ -237,6 +229,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e clock: clock, nudDisp: &disp, }, + stats: makeNICStats(), } nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), @@ -245,7 +238,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ @@ -323,15 +316,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { func TestEntryUnknownToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -350,9 +344,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } { @@ -367,7 +363,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) { func TestEntryUnknownToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) @@ -377,6 +373,7 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Unlock() // No probes should have been sent. + runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) linkRes.mu.Unlock() @@ -388,9 +385,11 @@ func TestEntryUnknownToStale(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -406,11 +405,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } - updatedAt := e.neigh.UpdatedAt + updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() clock.Advance(c.RetransmitTimer) @@ -437,7 +436,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want { t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -468,16 +467,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -487,7 +490,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) } e.mu.Unlock() @@ -495,23 +498,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { func TestEntryIncompleteToReachable(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -526,20 +522,35 @@ func TestEntryIncompleteToReachable(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -555,7 +566,7 @@ func TestEntryIncompleteToReachable(t *testing.T) { // to Reachable. func TestEntryAddsAndClearsWakers(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) w := sleep.Waker{} s := sleep.Sleeper{} @@ -563,7 +574,25 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { defer s.Done() e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() if got := e.wakers; got != nil { t.Errorf("got e.wakers = %v, want = nil", got) } @@ -587,34 +616,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -626,26 +645,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.isRouter, true; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -659,20 +668,38 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { } linkRes.mu.Unlock() + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -684,23 +711,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { func TestEntryIncompleteToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -715,20 +735,35 @@ func TestEntryIncompleteToStale(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -744,7 +779,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -783,16 +818,20 @@ func TestEntryIncompleteToFailed(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -817,12 +856,30 @@ func (*testLocker) Unlock() {} func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -848,34 +905,24 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -893,27 +940,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -928,20 +961,42 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -961,17 +1016,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -986,29 +1034,46 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1026,24 +1091,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1058,27 +1112,48 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1086,38 +1161,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1132,27 +1186,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1160,38 +1239,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1206,27 +1264,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1234,37 +1317,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1279,20 +1342,42 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1304,31 +1389,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1343,27 +1410,55 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1375,10 +1470,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } e.mu.Lock() - e.handlePacketQueuedLocked() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1400,41 +1513,33 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1446,31 +1551,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1485,27 +1572,55 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1517,27 +1632,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1552,27 +1653,51 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1584,24 +1709,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { func TestEntryStaleToDelay(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1616,27 +1730,48 @@ func TestEntryStaleToDelay(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1656,22 +1791,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleUpperLevelConfirmationLocked() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1686,43 +1809,68 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleUpperLevelConfirmationLocked() + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1743,29 +1891,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1780,43 +1909,75 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1837,13 +1998,31 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if e.neigh.State != Delay { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } @@ -1860,57 +2039,52 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1922,32 +2096,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1962,27 +2117,56 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1994,25 +2178,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2027,34 +2199,58 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2066,29 +2262,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2103,34 +2283,62 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2145,69 +2353,91 @@ func TestEntryDelayToProbe(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2228,36 +2458,50 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2274,37 +2518,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2312,12 +2566,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { @@ -2325,36 +2573,50 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2375,37 +2637,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2413,12 +2685,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { @@ -2426,36 +2692,51 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2479,30 +2760,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2529,17 +2818,14 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - wantProbes := []entryTestProbeInfo{ - // Probe caused by the Delay-to-Probe transition { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2567,42 +2853,51 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2622,36 +2917,50 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2672,49 +2981,60 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2734,36 +3054,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2781,49 +3115,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2843,36 +3188,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2890,49 +3249,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2946,87 +3316,238 @@ func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 c.MaxUnicastProbes = 3 + c.DelayFirstProbeTime = c.RetransmitTimer e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.Advance(waitFor) + // Observe each probe sent while in the Probe state. + for i := uint32(0); i < c.MaxUnicastProbes; i++ { + clock.Advance(c.RetransmitTimer) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + } - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.mu.Unlock() + } + + // Wait for the last probe to expire, causing a transition to Failed. + clock.Advance(c.RetransmitTimer) + e.mu.Lock() + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, + EventType: entryTestAdded, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, - // The next three probe are caused by the Delay-to-Probe transition. { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, + EventType: entryTestChanged, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryFailedToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // Verify the cache contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { + t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + } + + // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in + // their expected state. + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + clock.Advance(waitFor) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -3035,11 +3556,23 @@ func TestEntryProbeToFailed(t *testing.T) { } nudDisp.mu.Unlock() - e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + failedLookups := e.nic.stats.Neighbor.FailedEntryLookups + if got := failedLookups.Value(); got != 0 { + t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got) } + + e.mu.Lock() + // Verify queuing a packet to the entry immediately fails. + e.handlePacketQueuedLocked(entryTestAddr2) + state := e.neigh.State e.mu.Unlock() + if state != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", state, Failed) + } + + if got := failedLookups.Value(); got != 1 { + t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got) + } } func TestEntryFailedGetsDeleted(t *testing.T) { @@ -3054,84 +3587,106 @@ func TestEntryFailedGetsDeleted(t *testing.T) { } e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime clock.Advance(waitFor) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The next three probe are sent in Probe. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dcd4319bf..5d037a27e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -54,18 +54,20 @@ type NIC struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + // packetEPs is protected by mu, but the contained packetEndpointList are + // not. + packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList } } -// NICStats includes transmitted and received stats. +// NICStats hold statistics for a NIC. type NICStats struct { Tx DirectionStats Rx DirectionStats DisabledRx DirectionStats + + Neighbor NeighborStats } func makeNICStats() NICStats { @@ -80,6 +82,39 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } +type packetEndpointList struct { + mu sync.RWMutex + + // eps is protected by mu, but the contained PacketEndpoint values are not. + eps []PacketEndpoint +} + +func (p *packetEndpointList) add(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.eps = append(p.eps, ep) +} + +func (p *packetEndpointList) remove(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + for i, epOther := range p.eps { + if epOther == ep { + p.eps = append(p.eps[:i], p.eps[i+1:]...) + break + } + } +} + +// forEach calls fn with each endpoints in p while holding the read lock on p. +func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { + p.mu.RLock() + defer p.mu.RUnlock() + for _, ep := range p.eps { + fn(ep) + } +} + // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -100,7 +135,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -123,11 +158,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = new(packetEndpointList) } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = nil + nic.mu.packetEPs[netNum] = new(packetEndpointList) nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } @@ -170,7 +205,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -182,6 +217,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. @@ -232,7 +271,8 @@ func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Unlock() } -func (n *NIC) isPromiscuousMode() bool { +// Promiscuous implements NetworkInterface. +func (n *NIC) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -264,7 +304,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } return err @@ -273,6 +313,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb return n.writePacket(r, gso, protocol, pkt) } +// WritePacketToRemote implements NetworkInterface. +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + r := Route{ + NetProto: protocol, + } + r.ResolveWith(remoteLinkAddr) + return n.writePacket(&r, gso, protocol, pkt) +} + func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -311,16 +360,21 @@ func (n *NIC) setSpoofing(enable bool) { // primaryAddress returns an address that can be used to communicate with // remoteAddr. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { - n.mu.RLock() - spoofing := n.mu.spoofing - n.mu.RUnlock() - ep, ok := n.networkEndpoints[protocol] if !ok { return nil } - return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + + return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } type getAddressBehaviour int @@ -339,6 +393,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } +func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + return true + } + + return false +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) @@ -369,11 +433,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre // getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean // is passed to indicate whether or not we should generate temporary endpoints. func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { - if ep, ok := n.networkEndpoints[protocol]; ok { - return ep.AcquireAssignedAddress(address, createTemp, peb) + ep, ok := n.networkEndpoints[protocol] + if !ok { + return nil } - return nil + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb) } // addAddress adds a new address to n, so that it starts accepting packets @@ -384,7 +454,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return tcpip.ErrUnknownProtocol } - addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -397,7 +472,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PermanentAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PermanentAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -408,7 +488,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PrimaryAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PrimaryAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -426,13 +511,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit return tcpip.AddressWithPrefix{} } - return ep.MainAddress() + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.AddressWithPrefix{} + } + + return addressableEndpoint.MainAddress() } // removeAddress removes an address from n. func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { for _, ep := range n.networkEndpoints { - if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { continue } else { return err @@ -505,8 +600,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address return tcpip.ErrNotSupported } - _, err := gep.JoinGroup(addr) - return err + return gep.JoinGroup(addr) } // leaveGroup decrements the count for the given multicast address, and when it @@ -522,11 +616,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres return tcpip.ErrNotSupported } - if _, err := gep.LeaveGroup(addr); err != nil { - return err - } - - return nil + return gep.LeaveGroup(addr) } // isInGroup returns true if n has joined the multicast group addr. @@ -545,13 +635,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return false } -func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) - defer r.Release() - r.RemoteLinkAddress = remotelinkAddr - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) -} - // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the link endpoint. @@ -573,7 +656,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) - netProto, ok := n.stack.networkProtocols[protocol] + networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -585,23 +668,29 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if local == "" { local = n.LinkEndpoint.LinkAddress() } + pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? - packetEPs := n.mu.packetEPs[protocol] - // Add any other packet type sockets that may be listening for all protocols. - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) + protoEPs := n.mu.packetEPs[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + // Deliver to interested packet endpoints without holding NIC lock. + deliverPacketEPs := func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketHost ep.HandlePacket(n.id, local, protocol, p) } - - if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { - n.stack.stats.IP.PacketsReceived.Increment() + if protoEPs != nil { + protoEPs.forEach(deliverPacketEPs) + } + if anyEPs != nil { + anyEPs.forEach(deliverPacketEPs) } // Parse headers. + netProto := n.stack.NetworkProtocolInstance(protocol) transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { // The packet is too small to contain a network header. @@ -616,9 +705,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.IsLoopback() { + src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) if r := n.getAddress(protocol, src); r != nil { r.DecRef() @@ -631,78 +719,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - // Loopback traffic skips the prerouting chain. - if !n.IsLoopback() { - // iptables filtering. - ipt := n.stack.IPTables() - address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { - // iptables is telling us to drop the packet. - n.stack.stats.IP.IPTablesPreroutingDropped.Increment() - return - } - } - - if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { - n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) - return - } - - // This NIC doesn't care about the packet. Find a NIC that cares about the - // packet and forward it to the NIC. - // - // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding(protocol) { - r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) - if err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - // Found a NIC. - n := r.nic - if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { - if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() - r.RemoteLinkAddress = remote - r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) - addressEndpoint.DecRef() - r.Release() - return - } - - addressEndpoint.DecRef() - } - - // n doesn't have a destination endpoint. - // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - // We need to do a deep copy of the IP packet because WritePacket (and - // friends) take ownership of the packet buffer, but we do not own it. - Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } - - r.Release() - return - } - - // If a packet socket handled the packet, don't treat it as invalid. - if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } + networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. @@ -711,21 +728,22 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + eps := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + eps.forEach(func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) - } + }) } // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -737,7 +755,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, pkt) + n.stack.demux.deliverRawPacket(protocol, pkt) // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. @@ -766,14 +784,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN return TransportPacketHandled } - id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, pkt, id) { + netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber] + if !ok { + panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers())) + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) + id := TransportEndpointID{ + LocalPort: dstPort, + LocalAddress: dst, + RemotePort: srcPort, + RemoteAddress: src, + } + if n.stack.demux.deliverPacket(protocol, pkt, id) { return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, pkt) { + if state.defaultHandler(id, pkt) { return TransportPacketHandled } } @@ -781,7 +810,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet so // give the protocol specific error handler a chance to handle it. // If it doesn't handle it then we should do so. - switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() return TransportPacketHandled @@ -862,7 +891,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa if !ok { return tcpip.ErrNotSupported } - n.mu.packetEPs[netProto] = append(eps, ep) + eps.add(ep) return nil } @@ -875,17 +904,11 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep if !ok { return } - - for i, epOther := range eps { - if epOther == ep { - n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) - return - } - } + eps.remove(ep) } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address) +// packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed) unless the NIC is in spoofing mode, or temporary. func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 97a96af62..5b5c58afb 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip } // HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { -} +func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} // Close implements NetworkEndpoint.Close. func (e *testIPv6Endpoint) Close() { @@ -169,7 +168,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index e1ec15487..ab629b3a4 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -129,7 +129,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborAdded(tcpip.NICID, NeighborEntry) // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) // neighbor table changes state and/or link address. @@ -138,7 +138,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborChanged(tcpip.NICID, NeighborEntry) // OnNeighborRemoved will be called when an entry is removed from a NIC's // (with ID nicID) neighbor table. @@ -147,7 +147,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborRemoved(tcpip.NICID, NeighborEntry) } // ReachabilityConfirmationFlags describes the flags used within a reachability @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) + HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7f54a6de8..664cc6fa0 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -112,6 +112,16 @@ type PacketBuffer struct { // PktType indicates the SockAddrLink.PacketType of the packet as defined in // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType + + // NICID is the ID of the interface the network packet was received at. + NICID tcpip.NICID + + // RXTransportChecksumValidated indicates that transport checksum verification + // may be safely skipped. + RXTransportChecksumValidated bool + + // NetworkPacketInfo holds an incoming packet's network-layer information. + NetworkPacketInfo NetworkPacketInfo } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // Clone should be called in such cases so that no modifications is done to // underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, - TransportProtocolNumber: pk.TransportProtocolNumber, + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, + PktType: pk.PktType, + NICID: pk.NICID, + RXTransportChecksumValidated: pk.RXTransportChecksumValidated, + NetworkPacketInfo: pk.NetworkPacketInfo, } - return newPk +} + +// SourceLinkAddress returns the source link address of the packet. +func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { + link := pk.LinkHeader().View() + + if link.IsEmpty() { + return "" + } + + return header.Ethernet(link).SourceAddress() } // Network returns the network header as a header.Network. @@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network { } } +// CloneToInbound makes a shallow copy of the packet buffer to be used as an +// inbound packet. +// +// See PacketBuffer.Data for details about how a packet buffer holds an inbound +// packet. +func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { + return NewPacketBuffer(PacketBufferOptions{ + Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), + }) +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index f838eda8d..5d364a2b0 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } else if _, err := p.route.Resolve(nil); err != nil { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { - p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } p.route.Release() } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index defb9129b..b334e27c4 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -63,17 +63,24 @@ const ( ControlUnknown ) +// NetworkPacketInfo holds information about a network layer packet. +type NetworkPacketInfo struct { + // LocalAddressBroadcast is true if the packet's local address is a broadcast + // address. + LocalAddressBroadcast bool +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { // UniqueID returns an unique ID for this transport endpoint. UniqueID() uint64 - // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. It sets pkt.TransportHeader. + // HandlePacket is called by the stack when new packets arrive to this + // transport endpoint. It sets the packet buffer's transport header. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(TransportEndpointID, *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. @@ -105,8 +112,8 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(*PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -127,7 +134,7 @@ type PacketEndpoint interface { HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } -// UnknownDestinationPacketDisposition enumerates the possible return vaues from +// UnknownDestinationPacketDisposition enumerates the possible return values from // HandleUnknownDestinationPacket(). type UnknownDestinationPacketDisposition int @@ -172,9 +179,9 @@ type TransportProtocol interface { // protocol that don't match any existing endpoint. For example, // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // HandleUnknownDestinationPacket takes ownership of the packet if it handles // the issue. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition + HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -227,8 +234,8 @@ type TransportDispatcher interface { // // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition + // DeliverTransportPacket takes ownership of the packet. + DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -270,15 +277,11 @@ type NetworkHeaderParams struct { // An endpoint is considered to support group addressing when one or more // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { - // JoinGroup joins the spcified group. - // - // Returns true if the group was newly joined. - JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + // JoinGroup joins the specified group. + JoinGroup(group tcpip.Address) *tcpip.Error // LeaveGroup attempts to leave the specified group. - // - // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. - LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + LeaveGroup(group tcpip.Address) *tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool @@ -329,6 +332,9 @@ type AssignableAddressEndpoint interface { // AddressWithPrefix returns the endpoint's address. AddressWithPrefix() tcpip.AddressWithPrefix + // Subnet returns the subnet of the endpoint's address. + Subnet() tcpip.Subnet + // IsAssigned returns whether or not the endpoint is considered bound // to its NetworkEndpoint. IsAssigned(allowExpired bool) bool @@ -364,7 +370,7 @@ type AddressEndpoint interface { SetDeprecated(bool) } -// AddressKind is the kind of of an address. +// AddressKind is the kind of an address. // // See the values of AddressKind for more details. type AddressKind int @@ -490,13 +496,17 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // Promiscuous returns true if the interface is in promiscuous mode. + Promiscuous() bool + + // WritePacketToRemote writes the packet to the given remote link address. + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { - AddressableEndpoint - // Enable enables the endpoint. // // Must only be called when the stack is in a state that allows the endpoint @@ -544,7 +554,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + HandlePacket(pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -712,10 +722,6 @@ type LinkEndpoint interface { // endpoint. Capabilities() LinkEndpointCapabilities - // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. It takes ownership of vv. - WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error - // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // @@ -764,13 +770,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts - // the request on the local network if remoteLinkAddr is the zero value. The - // request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the link address of the target + // address. The request is broadcasted on the local network if a remote link + // address is not provided. // - // A valid response will cause the discovery protocol's network - // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error + // The request is sent from the passed network interface. If the interface + // local address is unspecified, any interface local address may be used. + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index b76e2d37b..de5fe6ffe 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -15,20 +15,25 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) // Route represents a route through the networking stack to a given destination. +// +// It is safe to call Route's methods from multiple goroutines. +// +// The exported fields are immutable. +// +// TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { // RemoteAddress is the final destination of the route. RemoteAddress tcpip.Address - // RemoteLinkAddress is the link-layer (MAC) address of the - // final destination of the route. - RemoteLinkAddress tcpip.LinkAddress - // LocalAddress is the local address where the route starts. LocalAddress tcpip.Address @@ -45,11 +50,24 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping - // nic is the NIC the route goes through. - nic *NIC + // localAddressNIC is the interface the address is associated with. + // TODO(gvisor.dev/issue/4548): Remove this field once we can query the + // address's assigned status without the NIC. + localAddressNIC *NIC - // addressEndpoint is the local address this route is associated with. - addressEndpoint AssignableAddressEndpoint + mu struct { + sync.RWMutex + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint + + // remoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + remoteLinkAddress tcpip.LinkAddress + } + + // outgoingNIC is the interface this route uses to write packets. + outgoingNIC *NIC // linkCache is set if link address resolution is enabled for this protocol on // the route's NIC. @@ -60,51 +78,139 @@ type Route struct { linkRes LinkAddressResolver } +// constructAndValidateRoute validates and initializes a route. It takes +// ownership of the provided local address. +// +// Returns an empty route if validation fails. +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { + if len(localAddr) == 0 { + localAddr = addressEndpoint.AddressWithPrefix().Address + } + + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { + addressEndpoint.DecRef() + return nil + } + + // If no remote address is provided, use the local address. + if len(remoteAddr) == 0 { + remoteAddr = localAddr + } + + r := makeRoute( + netProto, + localAddr, + remoteAddr, + outgoingNIC, + localAddressNIC, + addressEndpoint, + handleLocal, + multicastLoop, + ) + + // If the route requires us to send a packet through some gateway, do not + // broadcast it. + if len(gateway) > 0 { + r.NextHop = gateway + } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.ResolveWith(header.EthernetBroadcastAddress) + } + + return r +} + // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { + if localAddressNIC.stack != outgoingNIC.stack { + panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) + } + + if len(localAddr) == 0 { + localAddr = localAddressEndpoint.AddressWithPrefix().Address + } + loop := PacketOut - if handleLocal && localAddr != "" && remoteAddr == localAddr { - loop = PacketLoop - } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { - loop |= PacketLoop - } else if remoteAddr == header.IPv4Broadcast { - loop |= PacketLoop + + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if !outgoingNIC.IsLoopback() { + if handleLocal && localAddr != "" && remoteAddr == localAddr { + loop = PacketLoop + } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { + loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop + } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { + loop |= PacketLoop + } } - r := Route{ + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { + r := &Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), RemoteAddress: remoteAddr, - addressEndpoint: addressEndpoint, - nic: nic, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, Loop: loop, } - if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { + r.mu.Lock() + r.mu.localAddressEndpoint = localAddressEndpoint + r.mu.Unlock() + + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.nic.stack + r.linkCache = r.outgoingNIC.stack } } return r } +// makeLocalRoute initializes a new local route. It takes ownership of the +// provided AssignableAddressEndpoint. +// +// A local route is a route to a destination that is local to the stack. +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route { + loop := PacketLoop + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if outgoingNIC.IsLoopback() { + loop = PacketOut + } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in +// the route. +func (r *Route) RemoteLinkAddress() tcpip.LinkAddress { + r.mu.RLock() + defer r.mu.RUnlock() + return r.mu.remoteLinkAddress +} + // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.nic.ID() + return r.outgoingNIC.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.nic.stack.Stats() + return r.outgoingNIC.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -113,14 +219,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen) } -// Capabilities returns the link-layer capabilities of the route. -func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint.Capabilities() +// RequiresTXTransportChecksum returns false if the route does not require +// transport checksums to be populated. +func (r *Route) RequiresTXTransportChecksum() bool { + if r.local() { + return false + } + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0 +} + +// HasSoftwareGSOCapability returns true if the route supports software GSO. +func (r *Route) HasSoftwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 +} + +// HasHardwareGSOCapability returns true if the route supports hardware GSO. +func (r *Route) HasHardwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 +} + +// HasSaveRestoreCapability returns true if the route supports save/restore. +func (r *Route) HasSaveRestoreCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0 +} + +// HasDisconncetOkCapability returns true if the route supports disconnecting. +func (r *Route) HasDisconncetOkCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0 } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -129,7 +259,9 @@ func (r *Route) GSOMaxSize() uint32 { // ResolveWith immediately resolves a route with the specified remote link // address. func (r *Route) ResolveWith(addr tcpip.LinkAddress) { - r.RemoteLinkAddress = addr + r.mu.Lock() + defer r.mu.Unlock() + r.mu.remoteLinkAddress = addr } // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in @@ -142,7 +274,10 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { // // The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { - if !r.IsResolutionRequired() { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.isResolutionRequiredRLocked() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. return nil, nil @@ -152,26 +287,33 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if nextAddr == "" { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { - r.RemoteLinkAddress = r.LocalLinkAddress + r.mu.remoteLinkAddress = r.LocalLinkAddress return nil, nil } nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) + // If specified, the local address used for link address resolution must be an + // address on the outgoing interface. + var linkAddressResolutionRequestLocalAddr tcpip.Address + if r.localAddressNIC == r.outgoingNIC { + linkAddressResolutionRequestLocalAddr = r.LocalAddress + } + + if neigh := r.outgoingNIC.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) if err != nil { return ch, err } - r.RemoteLinkAddress = entry.LinkAddr + r.mu.remoteLinkAddress = entry.LinkAddr return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) if err != nil { return ch, err } - r.RemoteLinkAddress = linkAddr + r.mu.remoteLinkAddress = linkAddr return nil, nil } @@ -182,100 +324,146 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { + if neigh := r.outgoingNIC.neigh; neigh != nil { neigh.removeWaker(nextAddr, waker) return } - r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) + r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) +} + +// local returns true if the route is a local route. +func (r *Route) local() bool { + return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() } // IsResolutionRequired returns true if Resolve() must be called to resolve -// the link address before the this route can be written to. +// the link address before the route can be written to. // -// The NIC r uses must not be locked. +// The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if r.nic.neigh != nil { - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" + r.mu.RLock() + defer r.mu.RUnlock() + return r.isResolutionRequiredRLocked() +} + +func (r *Route) isResolutionRequiredRLocked() bool { + if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { + return false } - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" + + return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil +} + +func (r *Route) isValidForOutgoing() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isValidForOutgoingRLocked() +} + +func (r *Route) isValidForOutgoingRLocked() bool { + if !r.outgoingNIC.Enabled() { + return false + } + + localAddressEndpoint := r.mu.localAddressEndpoint + if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) { + return false + } + + // If the source NIC and outgoing NIC are different, make sure the stack has + // forwarding enabled, or the packet will be handled locally. + if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) { + return false + } + + return true } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.nic.getNetworkEndpoint(r.NetProto).MTU() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.addressEndpoint != nil { - r.addressEndpoint.DecRef() - r.addressEndpoint = nil + r.mu.Lock() + defer r.mu.Unlock() + + if r.mu.localAddressEndpoint != nil { + r.mu.localAddressEndpoint.DecRef() + r.mu.localAddressEndpoint = nil } } // Clone clones the route. -func (r *Route) Clone() Route { - if r.addressEndpoint != nil { - _ = r.addressEndpoint.IncRef() +func (r *Route) Clone() *Route { + r.mu.RLock() + defer r.mu.RUnlock() + + newRoute := &Route{ + RemoteAddress: r.RemoteAddress, + LocalAddress: r.LocalAddress, + LocalLinkAddress: r.LocalLinkAddress, + NextHop: r.NextHop, + NetProto: r.NetProto, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, } - return *r -} -// MakeLoopedRoute duplicates the given route with special handling for routes -// used for sending multicast or broadcast packets. In those cases the -// multicast/broadcast address is the remote address when sending out, but for -// incoming (looped) packets it becomes the local address. Similarly, the local -// interface address that was the local address going out becomes the remote -// address coming in. This is different to unicast routes where local and -// remote addresses remain the same as they identify location (local vs remote) -// not direction (source vs destination). -func (r *Route) MakeLoopedRoute() Route { - l := r.Clone() - if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { - l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress - l.RemoteLinkAddress = l.LocalLinkAddress + newRoute.mu.Lock() + defer newRoute.mu.Unlock() + newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint + if newRoute.mu.localAddressEndpoint != nil { + if !newRoute.mu.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress)) + } } - return l + newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress + + return newRoute } // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.nic.stack + return r.outgoingNIC.stack } func (r *Route) isV4Broadcast(addr tcpip.Address) bool { @@ -283,7 +471,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + r.mu.RLock() + localAddressEndpoint := r.mu.localAddressEndpoint + r.mu.RUnlock() + if localAddressEndpoint == nil { + return false + } + + subnet := localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } @@ -293,26 +488,3 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } - -// IsInboundBroadcast returns true if the route is for an inbound broadcast -// packet. -func (r *Route) IsInboundBroadcast() bool { - // Only IPv4 has a notion of broadcast. - return r.isV4Broadcast(r.LocalAddress) -} - -// ReverseRoute returns new route with given source and destination address. -func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { - return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - addressEndpoint: r.addressEndpoint, - nic: r.nic, - linkCache: r.linkCache, - linkRes: r.linkRes, - } -} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3a07577c8..dc4f5b3e7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "bytes" "encoding/binary" + "fmt" mathrand "math/rand" "sync/atomic" "time" @@ -52,7 +53,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -81,6 +82,7 @@ type TCPRACKState struct { FACK seqnum.Value RTT time.Duration Reord bool + DSACKSeen bool } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -518,6 +520,10 @@ type Options struct { // // RandSource must be thread-safe. RandSource mathrand.Source + + // IPTables are the initial iptables rules. If nil, iptables will allow + // all traffic. + IPTables *IPTables } // TransportEndpointInfo holds useful information about a transport endpoint @@ -620,6 +626,10 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } + if opts.IPTables == nil { + opts.IPTables = DefaultTables() + } + opts.NUDConfigs.resetInvalidFields() s := &Stack{ @@ -633,7 +643,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - tables: DefaultTables(), + tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, @@ -751,7 +761,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -830,6 +840,20 @@ func (s *Stack) AddRoute(route tcpip.Route) { s.routeTable = append(s.routeTable, route) } +// RemoveRoutes removes matching routes from the route table. +func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + var filteredRoutes []tcpip.Route + for _, route := range s.routeTable { + if !match(route) { + filteredRoutes = append(filteredRoutes, route) + } + } + s.routeTable = filteredRoutes +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] @@ -1057,7 +1081,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. Running: nic.Enabled(), - Promiscuous: nic.isPromiscuousMode(), + Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ @@ -1094,6 +1118,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } +// AddAddressWithPrefix is the same as AddAddress, but allows you to specify +// the address prefix. +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error { + ap := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: addr, + } + return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) +} + // AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { @@ -1180,54 +1214,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } +// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route +// from the specified NIC. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { + localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) + if localAddressEndpoint == nil { + return nil + } + + var outgoingNIC *NIC + // Prefer a local route to the same interface as the local address. + if localAddressNIC.hasAddress(netProto, remoteAddr) { + outgoingNIC = localAddressNIC + } + + // If the remote address isn't owned by the local address's NIC, check all + // NICs. + if outgoingNIC == nil { + for _, nic := range s.nics { + if nic.hasAddress(netProto, remoteAddr) { + outgoingNIC = nic + break + } + } + } + + // If the remote address is not owned by the stack, we can't return a local + // route. + if outgoingNIC == nil { + localAddressEndpoint.DecRef() + return nil + } + + r := makeLocalRoute( + netProto, + localAddr, + remoteAddr, + outgoingNIC, + localAddressNIC, + localAddressEndpoint, + ) + + if r.IsOutboundBroadcast() { + r.Release() + return nil + } + + return r +} + +// findLocalRouteRLocked returns a local route. +// +// A local route is a route to some remote address which the stack owns. That +// is, a local route is a route where packets never have to leave the stack. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { + if len(localAddr) == 0 { + localAddr = remoteAddr + } + + if localAddressNICID == 0 { + for _, localAddressNIC := range s.nics { + if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil { + return r + } + } + + return nil + } + + if localAddressNIC, ok := s.nics[localAddressNICID]; ok { + return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) + } + + return nil +} + // FindRoute creates a route to the given destination address, leaving through -// the given nic and local address (if provided). -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { +// the given NIC and local address (if provided). +// +// If a NIC is not specified, the returned route will leave through the same +// NIC as the NIC that has the local address assigned when forwarding is +// disabled. If forwarding is enabled and the NIC is unspecified, the route may +// leave through any interface unless the route is link-local. +// +// If no local address is provided, the stack will select a local address. If no +// remote address is provided, the stack wil use a remote address equal to the +// local address. +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() + isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) + needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) + + if s.handleLocal && !isMulticast && !isLocalBroadcast { + if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil { + return r, nil + } + } + + // If the interface is specified and we do not need a route, return a route + // through the interface if the interface is valid and enabled. if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.Enabled() { if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil + return makeRoute( + netProto, + localAddr, + remoteAddr, + nic, /* outboundNIC */ + nic, /* localAddressNIC*/ + addressEndpoint, + s.handleLocal, + multicastLoop, + ), nil } } - } else { - for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { - continue + + if isLoopback { + return nil, tcpip.ErrBadLocalAddress + } + return nil, tcpip.ErrNetworkUnreachable + } + + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + + // Find a route to the remote with the route table. + var chosenRoute tcpip.Route + for _, route := range s.routeTable { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } + + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) + if r == nil { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r, nil } - if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if len(remoteAddr) == 0 { - // If no remote address was provided, then the route - // provided will refer to the link local address. - remoteAddr = addressEndpoint.AddressWithPrefix().Address - } + } + + // If the stack has forwarding enabled and we haven't found a valid route to + // the remote address yet, keep track of the first valid route. We keep + // iterating because we prefer routes that let us use a local address that + // is assigned to the outgoing interface. There is no requirement to do this + // from any RFC but simply a choice made to better follow a strong host + // model which the netstack follows at the time of writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } + } + + if chosenRoute != (tcpip.Route{}) { + // At this point we know the stack has forwarding enabled since chosenRoute is + // only set when forwarding is enabled. + nic, ok := s.nics[chosenRoute.NIC] + if !ok { + // If the route's NIC was invalid, we should not have chosen the route. + panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC)) + } + + var gateway tcpip.Address + if needRoute { + gateway = chosenRoute.Gateway + } - r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) - if len(route.Gateway) > 0 { - if needRoute { - r.NextHop = route.Gateway - } - } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + // Use the specified NIC to get the local address endpoint. + if id != 0 { + if aNIC, ok := s.nics[id]; ok { + if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { + return r, nil } + } + } + + return nil, tcpip.ErrNoRoute + } + + if id == 0 { + // If an interface is not specified, try to find a NIC that holds the local + // address endpoint to construct a route. + for _, aNIC := range s.nics { + addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto) + if addressEndpoint == nil { + continue + } + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } } } - if !needRoute { - return Route{}, tcpip.ErrNetworkUnreachable + if needRoute { + return nil, tcpip.ErrNoRoute } - - return Route{}, tcpip.ErrNoRoute + if header.IsV6LoopbackAddress(remoteAddr) { + return nil, tcpip.ErrBadLocalAddress + } + return nil, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1323,7 +1528,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) } // Neighbors returns all IP to MAC address associations. @@ -1443,8 +1648,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { // FindTransportEndpoint finds an endpoint that most closely matches the provided // id. If no endpoint is found it returns nil. -func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { - return s.demux.findTransportEndpoint(netProto, transProto, id, r) +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, nicID) } // RegisterRawTransportEndpoint registers the given endpoint with the stack @@ -1615,49 +1820,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip nic.unregisterPacketEndpoint(netProto, ep) } -// WritePacket writes data directly to the specified NIC. It adds an ethernet -// header based on the arguments. -func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { - s.mu.Lock() - nic, ok := s.nics[nicID] - s.mu.Unlock() - if !ok { - return tcpip.ErrUnknownDevice - } - - // Add our own fake ethernet header. - ethFields := header.EthernetFields{ - SrcAddr: nic.LinkEndpoint.LinkAddress(), - DstAddr: dst, - Type: netProto, - } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - vv := buffer.View(fakeHeader).ToVectorisedView() - vv.Append(payload) - - if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { - return err - } - - return nil -} - -// WriteRawPacket writes data directly to the specified NIC without adding any -// headers. -func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { +// WritePacketToRemote writes a payload on the specified NIC using the provided +// network protocol and remote link address. +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice } - - if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { - return err - } - - return nil + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(nic.MaxHeaderLength()), + Data: payload, + }) + return nic.WritePacketToRemote(remote, nil, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the @@ -1717,7 +1893,6 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - // TODO: notify network of subscription via igmp protocol. s.mu.RLock() defer s.mu.RUnlock() @@ -1896,3 +2071,111 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { return tcpip.NewJob(s.clock, l, f) } + +// ParseResult indicates the result of a parsing attempt. +type ParseResult int + +const ( + // ParsedOK indicates that a packet was successfully parsed. + ParsedOK ParseResult = iota + + // UnknownNetworkProtocol indicates that the network protocol is unknown. + UnknownNetworkProtocol + + // NetworkLayerParseError indicates that the network packet was not + // successfully parsed. + NetworkLayerParseError + + // UnknownTransportProtocol indicates that the transport protocol is unknown. + UnknownTransportProtocol + + // TransportLayerParseError indicates that the transport packet was not + // successfully parsed. + TransportLayerParseError +) + +// ParsePacketBuffer parses the provided packet buffer. +func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { + netProto, ok := s.networkProtocols[protocol] + if !ok { + return UnknownNetworkProtocol + } + + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) + if !ok { + return NetworkLayerParseError + } + if !hasTransportHdr { + return ParsedOK + } + + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + return ParsedOK + } + + pkt.TransportProtocolNumber = transProtoNum + // Parse the transport header if present. + state, ok := s.transportProtocols[transProtoNum] + if !ok { + return UnknownTransportProtocol + } + + if !state.proto.Parse(pkt) { + return TransportLayerParseError + } + + return ParsedOK +} + +// networkProtocolNumbers returns the network protocol numbers the stack is +// configured with. +func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { + protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols)) + for p := range s.networkProtocols { + protos = append(protos, p) + } + return protos +} + +func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + return false + } + + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + return subnet.IsBroadcast(addr) +} + +// IsSubnetBroadcast returns true if the provided address is a subnet-local +// broadcast address on the specified NIC and protocol. +// +// Returns false if the NIC is unknown or if the protocol is unknown or does +// not support addressing. +// +// If the NIC is not specified, the stack will check all NICs. +func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + if nicID != 0 { + nic, ok := s.nics[nicID] + if !ok { + return false + } + + return isSubnetBroadcastOnNIC(nic, protocol, addr) + } + + for _, nic := range s.nics { + if isSubnetBroadcastOnNIC(nic, protocol, addr) { + return true + } + } + + return false +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e75f58c64..457990945 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,12 +21,12 @@ import ( "bytes" "fmt" "math" + "net" "sort" "testing" "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -108,12 +108,21 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ + netHdr := pkt.NetworkHeader().View() + + dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + return + } + addressEndpoint.DecRef() + + f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { + if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -129,7 +138,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -151,12 +160,13 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + pkt.NetworkProtocolNumber = fakeNetNumber hdr[dstAddrOffset] = r.RemoteAddress[0] hdr[srcAddrOffset] = r.LocalAddress[0] hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - f.HandlePacket(r, pkt) + f.HandlePacket(pkt.Clone()) } if r.Loop&stack.PacketOut == 0 { return nil @@ -254,6 +264,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto if !ok { return 0, false, false } + pkt.NetworkProtocolNumber = fakeNetNumber return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -395,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -413,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En } } -func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() ep.Drain() if err := send(r, payload); err != nil { @@ -424,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer. } } -func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) @@ -1334,6 +1345,106 @@ func TestPromiscuousMode(t *testing.T) { testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } +// TestExternalSendWithHandleLocal tests that the stack creates a non-local +// route when spoofing or promiscuous mode are enabled. +// +// This test makes sure that packets are transmitted from the stack. +func TestExternalSendWithHandleLocal(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + + localAddr = tcpip.Address("\x01") + dstAddr = tcpip.Address("\x03") + ) + + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + configureStack func(*testing.T, *stack.Stack) + }{ + { + name: "Default", + configureStack: func(*testing.T, *stack.Stack) {}, + }, + { + name: "Spoofing", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetSpoofing(nicID, true); err != nil { + t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) + } + }, + }, + { + name: "Promiscuous", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetPromiscuousMode(nicID, true); err != nil { + t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, handleLocal := range []bool{true, false} { + t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + HandleLocal: handleLocal, + }) + + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) + + test.configureStack(t, s) + + r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) + } + defer r.Release() + + if r.LocalAddress != localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) + } + + if n := ep.Drain(); n != 0 { + t.Fatalf("got ep.Drain() = %d, want = 0", n) + } + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: fakeTransNumber, + TTL: 123, + TOS: stack.DefaultTOS, + }, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.NewView(10).ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, _, _): %s", err) + } + if n := ep.Drain(); n != 1 { + t.Fatalf("got ep.Drain() = %d, want = 1", n) + } + }) + } + }) + } +} + func TestSpoofingWithAddress(t *testing.T) { localAddr := tcpip.Address("\x01") nonExistentLocalAddr := tcpip.Address("\x02") @@ -1451,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) { // testSendTo(t, s, remoteAddr, ep, nil) } -func verifyRoute(gotRoute, wantRoute stack.Route) error { +func verifyRoute(gotRoute, wantRoute *stack.Route) error { if gotRoute.LocalAddress != wantRoute.LocalAddress { return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) } if gotRoute.RemoteAddress != wantRoute.RemoteAddress { return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) } - if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { - return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want { + return fmt.Errorf("bad remote link address: got %s, want = %s", got, want) } if gotRoute.NextHop != wantRoute.NextHop { return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) @@ -1491,7 +1602,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1545,7 +1656,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1555,7 +1666,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } @@ -1571,7 +1682,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -2108,88 +2219,6 @@ func TestNICStats(t *testing.T) { } } -func TestNICForwarding(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const dstAddr = tcpip.Address("\x03") - - tests := []struct { - name string - headerLen uint16 - }{ - { - name: "Zero header length", - }, - { - name: "Non-zero header length", - headerLen: 16, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) - } - - ep2 := channelLinkWithHeaderLength{ - Endpoint: channel.New(10, defaultMTU, ""), - headerLength: test.headerLen, - } - if err := s.CreateNIC(nicID2, &ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) - } - - // Route all packets to dstAddr to NIC 2. - { - subnet, err := tcpip.NewSubnet(dstAddr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) - } - - // Send a packet to dstAddr. - buf := buffer.NewView(30) - buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not forwarded") - } - - // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { - t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - }) - } -} - // TestNICContextPreservation tests that you can read out via stack.NICInfo the // Context data you pass via NICContext.Context in stack.CreateNICWithOptions. func TestNICContextPreservation(t *testing.T) { @@ -2377,9 +2406,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, + AutoGenLinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, })}, } @@ -2472,8 +2501,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, + AutoGenLinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, })}, } @@ -2506,9 +2535,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpConfigs := ipv6.DefaultNDPConfigurations() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + AutoGenLinkLocal: true, + NDPDisp: &ndpDisp, })}, } @@ -3321,11 +3350,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { remNetSubnetBcast := remNetSubnet.Broadcast() tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - expectedRoute stack.Route + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedLocalAddress tcpip.Address + expectedRemoteAddress tcpip.Address + expectedRemoteLinkAddress tcpip.LinkAddress + expectedNextHop tcpip.Address + expectedNetProto tcpip.NetworkProtocolNumber + expectedLoop stack.PacketLooping }{ // Broadcast to a locally attached subnet populates the broadcast MAC. { @@ -3340,14 +3374,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: ipv4SubnetBcast, - RemoteLinkAddress: header.EthernetBroadcastAddress, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4SubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: ipv4SubnetBcast, + expectedRemoteLinkAddress: header.EthernetBroadcastAddress, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut | stack.PacketLoop, }, // Broadcast to a locally attached /31 subnet does not populate the // broadcast MAC. @@ -3363,13 +3395,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet31Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix31.Address, - RemoteAddress: ipv4Subnet31Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet31Bcast, + expectedLocalAddress: ipv4AddrPrefix31.Address, + expectedRemoteAddress: ipv4Subnet31Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a locally attached /32 subnet does not populate the // broadcast MAC. @@ -3385,13 +3415,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet32Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix32.Address, - RemoteAddress: ipv4Subnet32Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet32Bcast, + expectedLocalAddress: ipv4AddrPrefix32.Address, + expectedRemoteAddress: ipv4Subnet32Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // IPv6 has no notion of a broadcast. { @@ -3406,13 +3434,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv6SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv6Addr.Address, - RemoteAddress: ipv6SubnetBcast, - NetProto: header.IPv6ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv6SubnetBcast, + expectedLocalAddress: ipv6Addr.Address, + expectedRemoteAddress: ipv6SubnetBcast, + expectedNetProto: header.IPv6ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a remote subnet in the route table is send to the next-hop // gateway. @@ -3429,14 +3455,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to an unknown subnet follows the default route. Note that this // is essentially just routing an unknown destination IP, because w/o any @@ -3454,14 +3478,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, } @@ -3490,10 +3512,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */) + if err != nil { t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) - } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { - t.Errorf("route mismatch (-want +got):\n%s", diff) + } + if r.LocalAddress != test.expectedLocalAddress { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress) + } + if r.RemoteAddress != test.expectedRemoteAddress { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress) + } + if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress { + t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress) + } + if r.NextHop != test.expectedNextHop { + t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop) + } + if r.NetProto != test.expectedNetProto { + t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto) + } + if r.Loop != test.expectedLoop { + t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop) } }) } @@ -3672,3 +3711,515 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) } } + +// TestAddRoute tests Stack.AddRoute +func TestAddRoute(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + subnet1, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + + expected := []tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + } + + // Initialize the route table with one route. + s.SetRouteTable([]tcpip.Route{expected[0]}) + + // Add another route. + s.AddRoute(expected[1]) + + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +// TestRemoveRoutes tests Stack.RemoveRoutes +func TestRemoveRoutes(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + addressToRemove := tcpip.Address("\x01") + subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet3, err := tcpip.NewSubnet("\x02", "\x02") + if err != nil { + t.Fatal(err) + } + + // Initialize the route table with three routes. + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + {Destination: subnet3, Gateway: "\x00", NIC: 1}, + }) + + // Remove routes with the specific address. + s.RemoveRoutes(func(r tcpip.Route) bool { + return r.Destination.ID() == addressToRemove + }) + + expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +func TestFindRouteWithForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Addr = tcpip.Address("\x01") + nic2Addr = tcpip.Address("\x02") + remoteAddr = tcpip.Address("\x03") + ) + + type netCfg struct { + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1Addr tcpip.Address + nic2Addr tcpip.Address + remoteAddr tcpip.Address + } + + fakeNetCfg := netCfg{ + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1Addr: nic1Addr, + nic2Addr: nic2Addr, + remoteAddr: remoteAddr, + } + + globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) + globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) + + ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: llAddr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: globalIPv6Addr1, + } + ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: llAddr1, + remoteAddr: llAddr2, + } + ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + } + + tests := []struct { + name string + + netCfg netCfg + forwardingEnabled bool + + addrNIC tcpip.NICID + localAddr tcpip.Address + + findRouteErr *tcpip.Error + dependentOnForwarding bool + }{ + { + name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: true, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) + } + + if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + } + + if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + } + + if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) + + r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if r != nil { + defer r.Release() + } + if err != test.findRouteErr { + t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + } + + if test.findRouteErr != nil { + return + } + + if r.LocalAddress != test.localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr) + } + if r.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr) + } + + if t.Failed() { + t.FailNow() + } + + // Sending a packet should always go through NIC2 since we only install a + // route to test.netCfg.remoteAddr through NIC2. + data := buffer.View([]byte{1, 2, 3, 4}) + if err := send(r, data); err != nil { + t.Fatalf("send(_, _): %s", err) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not sent through ep2") + } + if pkt.Route.LocalAddress != test.localAddr { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + } + if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) + } + + if !test.forwardingEnabled || !test.dependentOnForwarding { + return + } + + // Disabling forwarding when the route is dependent on forwarding being + // enabled should make the route invalid. + if err := s.SetForwarding(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + } + if err := send(r, data); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + if n := ep2.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep2", n) + } + }) + } +} + +func TestWritePacketToRemote(t *testing.T) { + const nicID = 1 + const MTU = 1280 + e := channel.New(1, MTU, linkAddr1) + s := stack.New(stack.Options{}) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("CreateNIC(%d) = %s", nicID, err) + } + tests := []struct { + name string + protocol tcpip.NetworkProtocolNumber + payload []byte + }{ + { + name: "SuccessIPv4", + protocol: header.IPv4ProtocolNumber, + payload: []byte{1, 2, 3, 4}, + }, + { + name: "SuccessIPv6", + protocol: header.IPv6ProtocolNumber, + payload: []byte{5, 6, 7, 8}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) + } + + pkt, ok := e.Read() + if got, want := ok, true; got != want { + t.Fatalf("e.Read() = %t, want %t", got, want) + } + if got, want := pkt.Proto, test.protocol; got != want { + t.Fatalf("pkt.Proto = %d, want %d", got, want) + } + if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want) + } + if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) + } + }) + } + + t.Run("InvalidNICID", func(t *testing.T) { + if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + } + pkt, ok := e.Read() + if got, want := ok, false; got != want { + t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) + } + }) +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 35e5b1a2e..f183ec6e4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isInboundMulticastOrBroadcast(r) { - mpep.handlePacketAll(r, id, pkt) + if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { + mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { - queuedProtocol.QueuePacket(r, transEP, id, pkt) + queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() return } - transEP.HandlePacket(r, id, pkt) + transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } @@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T // based on endpoints IDs. It should only be instantiated via // newTransportDemuxer. type transportDemuxer struct { + stack *Stack + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints queuedProtocols map[protocolIDs]queuedTransportProtocol @@ -262,11 +264,12 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) + QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { d := &transportDemuxer{ + stack: stack, protocol: make(map[protocolIDs]*transportEndpoints), queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), } @@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) } else { - endpoint.HandlePacket(r, id, pkt.Clone()) + endpoint.HandlePacket(id, pkt.Clone()) } } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) + queuedProtocol.QueuePacket(endpoint, id, pkt) } else { - endpoint.HandlePacket(r, id, pkt) + endpoint.HandlePacket(id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() return false } // handlePacket takes ownership of pkt, so each endpoint needs its own // copy except for the final one. for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(r, id, pkt.Clone()) + ep.handlePacket(id, pkt.Clone()) } - destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + destEPs[len(destEPs)-1].handlePacket(id, pkt) return true } @@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // destination address, then do nothing further and instruct the caller to do // the same. The network layer handles address validation for specified source // addresses. - if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { + if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) { // TCP can only be used to communicate between a single source and a - // single destination; the addresses must be unicast. - r.Stats().TCP.InvalidSegmentsReceived.Increment() + // single destination; the addresses must be unicast.e + d.stack.stats.TCP.InvalidSegmentsReceived.Increment() return true } @@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() } return false } - ep.handlePacket(r, id, pkt) + ep.handlePacket(id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } @@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, pkt) + rawEP.HandlePacket(pkt.Clone()) foundRaw = true } eps.mu.RUnlock() @@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // findTransportEndpoint find a single endpoint that most closely matches the provided id. -func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil @@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[nicID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isInboundMulticastOrBroadcast(r *Route) bool { - return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) +func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool { + return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr) } func isSpecified(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 698c8609e..a692af20b 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -102,7 +102,6 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI // Initialize the IP header. ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TOS: 0x80, TotalLength: uint16(len(buf)), TTL: 65, @@ -142,11 +141,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -308,9 +307,7 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { - t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) - } + ep.SocketOptions().SetReusePort(endpoint.reuse) bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 62ab6d92f..66eb562ba 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -28,7 +27,7 @@ import ( const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 + fakeTransHeaderLen int = 3 ) // fakeTransportEndpoint is a transport-layer protocol endpoint. It counts @@ -39,14 +38,18 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler proto *fakeTransportProtocol peerAddr tcpip.Address - route stack.Route + route *stack.Route uniqueID uint64 // acceptQueue is non-nil iff bound. - acceptQueue []fakeTransportEndpoint + acceptQueue []*fakeTransportEndpoint + + // ops is used to set and get socket options. + ops tcpip.SocketOptions } func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { @@ -59,8 +62,14 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { + return &f.ops +} + func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep.ops.InitHandler(ep) + return ep } func (f *fakeTransportEndpoint) Abort() { @@ -100,8 +109,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // SetSockOpt sets a socket option. Currently not supported. @@ -109,21 +118,11 @@ func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Erro return tcpip.ErrInvalidEndpointState } -// SetSockOptBool sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - // SetSockOptInt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -186,7 +185,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai } a := f.acceptQueue[0] f.acceptQueue = f.acceptQueue[1:] - return &a, nil, nil + return a, nil, nil } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { @@ -201,7 +200,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { ); err != nil { return err } - f.acceptQueue = []fakeTransportEndpoint{} + f.acceptQueue = []*fakeTransportEndpoint{} return nil } @@ -213,20 +212,31 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) + if f.acceptQueue == nil { + return + } + + netHdr := pkt.NetworkHeader().View() + route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return } + route.ResolveWith(pkt.SourceLinkAddress()) + + ep := &fakeTransportEndpoint{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, + proto: f.proto, + peerAddr: route.RemoteAddress, + route: route, + } + ep.ops.InitHandler(ep) + f.acceptQueue = append(f.acceptQueue, ep) } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { @@ -288,7 +298,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } @@ -544,87 +554,3 @@ func TestTransportOptions(t *testing.T) { t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } - -func TestTransportForwarding(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - // TODO(b/123449044): Change this to a channel NIC. - ep1 := loopback.New() - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - { - subnet0, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - }) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: req.ToVectorisedView(), - })) - - aep, _, err := ep.Accept(nil) - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - p, ok := ep2.Read() - if !ok { - t.Fatal("Response packet not forwarded") - } - - nh := stack.PayloadSince(p.Pkt.NetworkHeader()) - if dst := nh[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := nh[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d77848d61..45fa62720 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -49,8 +49,9 @@ const ipv4AddressSize = 4 // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // -// Note: to support save / restore, it is important that all tcpip errors have -// distinct error messages. +// All errors must have unique msg strings. +// +// +stateify savable type Error struct { msg string @@ -247,6 +248,16 @@ func (a Address) WithPrefix() AddressWithPrefix { } } +// Unspecified returns true if the address is unspecified. +func (a Address) Unspecified() bool { + for _, b := range a { + if b != 0 { + return false + } + } + return true +} + // AddressMask is a bitmask for an address. type AddressMask string @@ -356,10 +367,9 @@ func (s *Subnet) IsBroadcast(address Address) bool { return s.Prefix() <= 30 && s.Broadcast() == address } -// Equal returns true if s equals o. -// -// Needed to use cmp.Equal on Subnet as its fields are unexported. +// Equal returns true if this Subnet is equal to the given Subnet. func (s Subnet) Equal(o Subnet) bool { + // If this changes, update Route.Equal accordingly. return s == o } @@ -482,6 +492,14 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is + // set. + HasOriginalDstAddress bool + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress FullAddress } // PacketOwner is used to get UID and GID of the packet. @@ -536,7 +554,7 @@ type Endpoint interface { // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. - Peek([][]byte) (int64, ControlMessages, *Error) + Peek([][]byte) (int64, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -594,10 +612,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt SettableSocketOption) *Error - // SetSockOptBool sets a socket option, for simple cases where a value - // has the bool type. - SetSockOptBool(opt SockOptBool, v bool) *Error - // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. SetSockOptInt(opt SockOptInt, v int) *Error @@ -605,10 +619,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt GettableSocketOption) *Error - // GetSockOptBool gets a socket option for simple cases where a return - // value has the bool type. - GetSockOptBool(SockOptBool) (bool, *Error) - // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. GetSockOptInt(SockOptInt) (int, *Error) @@ -635,6 +645,10 @@ type Endpoint interface { // LastError clears and returns the last error reported by the endpoint. LastError() *Error + + // SocketOptions returns the structure which contains all the socket + // level options. + SocketOptions() *SocketOptions } // LinkPacketInfo holds Link layer information for a received packet. @@ -691,80 +705,6 @@ type WriteOptions struct { Atomic bool } -// SockOptBool represents socket options which values have the bool type. -type SockOptBool int - -const ( - // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify - // whether datagram sockets are allowed to send packets to a broadcast - // address. - BroadcastOption SockOptBool = iota - - // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if - // data should be held until segments are full by the TCP transport - // protocol. - CorkOption - - // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if - // data should be sent out immediately by the transport protocol. For - // TCP, it determines if the Nagle algorithm is on or off. - DelayOption - - // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to - // specify whether TCP keepalive is enabled for this socket. - KeepaliveEnabledOption - - // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to - // specify whether multicast packets sent over a non-loopback interface - // will be looped back. - MulticastLoopOption - - // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify - // whether UDP checksum is disabled for this socket. - NoChecksumOption - - // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify - // whether SCM_CREDENTIALS socket control messages are enabled. - // - // Only supported on Unix sockets. - PasscredOption - - // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. - QuickAckOption - - // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to - // specify if the IPV6_TCLASS ancillary message is passed with incoming - // packets. - ReceiveTClassOption - - // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify - // if the TOS ancillary message is passed with incoming packets. - ReceiveTOSOption - - // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to - // specify if more inforamtion is provided with incoming packets such as - // interface index and address. - ReceiveIPPacketInfoOption - - // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to - // specify whether Bind() should allow reuse of local address. - ReuseAddressOption - - // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit - // multiple sockets to be bound to an identical socket address. - ReusePortOption - - // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify - // whether an IPv6 socket is to be restricted to sending and receiving - // IPv6 packets only. - V6OnlyOption - - // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw - // endpoint that all packets being written have an IP header and the - // endpoint should not attach an IP header. - IPHdrIncludedOption -) - // SockOptInt represents socket options which values have the int type. type SockOptInt int @@ -1156,14 +1096,6 @@ type RemoveMembershipOption MembershipOption func (*RemoveMembershipOption) isSettableSocketOption() {} -// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP out-of-band data is delivered along with the normal in-band data. -type OutOfBandInlineOption int - -func (*OutOfBandInlineOption) isGettableSocketOption() {} - -func (*OutOfBandInlineOption) isSettableSocketOption() {} - // SocketDetachFilterOption is used by SetSockOpt to detach a previously attached // classic BPF filter on a given endpoint. type SocketDetachFilterOption int @@ -1213,10 +1145,6 @@ type LingerOption struct { Timeout time.Duration } -func (*LingerOption) isGettableSocketOption() {} - -func (*LingerOption) isSettableSocketOption() {} - // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable @@ -1256,6 +1184,12 @@ func (r Route) String() string { return out.String() } +// Equal returns true if the given Route is equal to this Route. +func (r Route) Equal(to Route) bool { + // NOTE: This relies on the fact that r.Destination == to.Destination + return r == to +} + // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 @@ -1381,6 +1315,18 @@ type ICMPv6PacketStats struct { // RedirectMsg is the total number of ICMPv6 redirect message packets // counted. RedirectMsg *StatCounter + + // MulticastListenerQuery is the total number of Multicast Listener Query + // messages counted. + MulticastListenerQuery *StatCounter + + // MulticastListenerReport is the total number of Multicast Listener Report + // messages counted. + MulticastListenerReport *StatCounter + + // MulticastListenerDone is the total number of Multicast Listener Done + // messages counted. + MulticastListenerDone *StatCounter } // ICMPv4SentPacketStats collects outbound ICMPv4-specific stats. @@ -1422,6 +1368,10 @@ type ICMPv6SentPacketStats struct { type ICMPv6ReceivedPacketStats struct { ICMPv6PacketStats + // Unrecognized is the total number of ICMPv6 packets received that the + // transport layer does not know how to parse. + Unrecognized *StatCounter + // Invalid is the total number of ICMPv6 packets received that the // transport layer could not parse. Invalid *StatCounter @@ -1431,33 +1381,102 @@ type ICMPv6ReceivedPacketStats struct { RouterOnlyPacketsDroppedByHost *StatCounter } -// ICMPStats collects ICMP-specific stats (both v4 and v6). -type ICMPStats struct { +// ICMPv4Stats collects ICMPv4-specific stats. +type ICMPv4Stats struct { // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type // and a single count of packets which failed to write to the link // layer. - V4PacketsSent ICMPv4SentPacketStats + PacketsSent ICMPv4SentPacketStats // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4 // packet type and a single count of invalid packets received. - V4PacketsReceived ICMPv4ReceivedPacketStats + PacketsReceived ICMPv4ReceivedPacketStats +} +// ICMPv6Stats collects ICMPv6-specific stats. +type ICMPv6Stats struct { // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type // and a single count of packets which failed to write to the link // layer. - V6PacketsSent ICMPv6SentPacketStats + PacketsSent ICMPv6SentPacketStats // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6 // packet type and a single count of invalid packets received. - V6PacketsReceived ICMPv6ReceivedPacketStats + PacketsReceived ICMPv6ReceivedPacketStats +} + +// ICMPStats collects ICMP-specific stats (both v4 and v6). +type ICMPStats struct { + // V4 contains the ICMPv4-specifics stats. + V4 ICMPv4Stats + + // V6 contains the ICMPv4-specifics stats. + V6 ICMPv6Stats +} + +// IGMPPacketStats enumerates counts for all IGMP packet types. +type IGMPPacketStats struct { + // MembershipQuery is the total number of Membership Query messages counted. + MembershipQuery *StatCounter + + // V1MembershipReport is the total number of Version 1 Membership Report + // messages counted. + V1MembershipReport *StatCounter + + // V2MembershipReport is the total number of Version 2 Membership Report + // messages counted. + V2MembershipReport *StatCounter + + // LeaveGroup is the total number of Leave Group messages counted. + LeaveGroup *StatCounter +} + +// IGMPSentPacketStats collects outbound IGMP-specific stats. +type IGMPSentPacketStats struct { + IGMPPacketStats + + // Dropped is the total number of IGMP packets dropped. + Dropped *StatCounter +} + +// IGMPReceivedPacketStats collects inbound IGMP-specific stats. +type IGMPReceivedPacketStats struct { + IGMPPacketStats + + // Invalid is the total number of IGMP packets received that IGMP could not + // parse. + Invalid *StatCounter + + // ChecksumErrors is the total number of IGMP packets dropped due to bad + // checksums. + ChecksumErrors *StatCounter + + // Unrecognized is the total number of unrecognized messages counted, these + // are silently ignored for forward-compatibilty. + Unrecognized *StatCounter +} + +// IGMPStats colelcts IGMP-specific stats. +type IGMPStats struct { + // IGMPSentPacketStats contains counts of sent packets by IGMP packet type + // and a single count of invalid packets received. + PacketsSent IGMPSentPacketStats + + // IGMPReceivedPacketStats contains counts of received packets by IGMP packet + // type and a single count of invalid packets received. + PacketsReceived IGMPReceivedPacketStats } // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // PacketsReceived is the total number of IP packets received from the - // link layer in nic.DeliverNetworkPacket. + // link layer. PacketsReceived *StatCounter + // DisabledPacketsReceived is the total number of IP packets received from the + // link layer when the IP layer is disabled. + DisabledPacketsReceived *StatCounter + // InvalidDestinationAddressesReceived is the total number of IP packets // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived *StatCounter @@ -1496,6 +1515,15 @@ type IPStats struct { // IPTablesOutputDropped is the total number of IP packets dropped in // the Output chain. IPTablesOutputDropped *StatCounter + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived *StatCounter + + // OptionRRReceived is the number of Record Route options seen. + OptionRRReceived *StatCounter + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived *StatCounter } // TCPStats collects TCP-specific stats. @@ -1644,6 +1672,9 @@ type Stats struct { // ICMP breaks out ICMP-specific stats (both v4 and v6). ICMP ICMPStats + // IGMP breaks out IGMP-specific stats. + IGMP IGMPStats + // IP breaks out IP-specific stats (both v4 and v6). IP IPStats diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 1c8e2bc34..c461da137 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -226,3 +226,47 @@ func TestAddressWithPrefixSubnet(t *testing.T) { } } } + +func TestAddressUnspecified(t *testing.T) { + tests := []struct { + addr Address + unspecified bool + }{ + { + addr: "", + unspecified: true, + }, + { + addr: "\x00", + unspecified: true, + }, + { + addr: "\x01", + unspecified: false, + }, + { + addr: "\x00\x00", + unspecified: true, + }, + { + addr: "\x01\x00", + unspecified: false, + }, + { + addr: "\x00\x01", + unspecified: false, + }, + { + addr: "\x01\x01", + unspecified: false, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) { + if got := test.addr.Unspecified(); got != test.unspecified { + t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 34aab32d0..800025fb9 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -10,6 +10,7 @@ go_test( "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", + "route_test.go", ], deps = [ "//pkg/tcpip", @@ -24,6 +25,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 0dcef7b04..39343b966 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -33,11 +33,6 @@ import ( func TestForwarding(t *testing.T) { const ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") - routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - host1NICID = 1 routerNICID1 = 2 routerNICID2 = 3 @@ -166,6 +161,38 @@ func TestForwarding(t *testing.T) { } }, }, + { + name: "IPv4 host2 server with routerNIC1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host2IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + { + name: "IPv6 routerNIC2 server with host1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host1IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, } for _, test := range tests { @@ -179,8 +206,8 @@ func TestForwarding(t *testing.T) { routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr) - routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr) + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -202,19 +229,6 @@ func TestForwarding(t *testing.T) { t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) } - if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) } @@ -321,12 +335,8 @@ func TestForwarding(t *testing.T) { if err == tcpip.ErrNoLinkAddress { // Wait for link resolution to complete. <-ch - n, _, err = ep.Write(dataPayload, wOpts) - } else if err != nil { - t.Fatalf("ep.Write(_, _): %s", err) } - if err != nil { t.Fatalf("ep.Write(_, _): %s", err) } @@ -343,7 +353,6 @@ func TestForwarding(t *testing.T) { // Wait for the endpoint to be readable. <-ch - var addr tcpip.FullAddress v, _, err := ep.Read(&addr) if err != nil { diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 6ddcda70c..bf8a1241f 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -32,32 +32,36 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +const ( + linkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +) - host1IPv4Addr = tcpip.ProtocolAddress{ +var ( + ipv4Addr1 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - host2IPv4Addr = tcpip.ProtocolAddress{ + ipv4Addr2 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr1 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr2 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), @@ -89,7 +93,7 @@ func TestPing(t *testing.T) { name: "IPv4 Ping", transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, - remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) @@ -104,7 +108,7 @@ func TestPing(t *testing.T) { name: "IPv6 Ping", transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, - remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) @@ -127,7 +131,7 @@ func TestPing(t *testing.T) { host1Stack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr) + host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -136,43 +140,36 @@ func TestPing(t *testing.T) { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } - if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, tcpip.Route{ - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, }) host2Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, tcpip.Route{ - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, }) diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index e8caf09ba..baaa741cd 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -70,8 +71,8 @@ func TestInitialLoopbackAddresses(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDispatcher{}, - AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDispatcher{}, + AutoGenLinkLocal: true, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(nicID tcpip.NICID, nicName string) string { t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) @@ -93,9 +94,10 @@ func TestInitialLoopbackAddresses(t *testing.T) { } } -// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address. -func TestLoopbackAcceptAllInSubnet(t *testing.T) { +// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers +// itself bound to all addresses in the subnet of an assigned address and UDP +// traffic is sent/received correctly. +func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { const ( nicID = 1 localPort = 80 @@ -107,7 +109,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr, } - ipv4Bytes := []byte(ipv4Addr.Address) + ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) ipv4Bytes[len(ipv4Bytes)-1]++ otherIPv4Address := tcpip.Address(ipv4Bytes) @@ -129,7 +131,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { { name: "IPv4 bind to wildcard and send to assigned address", addAddress: ipv4ProtocolAddress, - dstAddr: ipv4Addr.Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, expectRx: true, }, { @@ -148,7 +150,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { name: "IPv4 bind to other subnet-local address and send to assigned address", addAddress: ipv4ProtocolAddress, bindAddr: otherIPv4Address, - dstAddr: ipv4Addr.Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, expectRx: false, }, { @@ -161,7 +163,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { { name: "IPv4 bind to assigned address and send to other subnet-local address", addAddress: ipv4ProtocolAddress, - bindAddr: ipv4Addr.Address, + bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, dstAddr: otherIPv4Address, expectRx: false, }, @@ -204,7 +206,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { }, }) - wq := waiter.Queue{} + var wq waiter.Queue rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) @@ -236,13 +238,17 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) } - if gotPayload, _, err := rep.Read(nil); test.expectRx { + var addr tcpip.FullAddress + if gotPayload, _, err := rep.Read(&addr); test.expectRx { if err != nil { - t.Fatalf("reep.Read(nil): %s", err) + t.Fatalf("reep.Read(_): %s", err) } if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } + if addr.Addr != test.addAddress.AddressWithPrefix.Address { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) + } } else { if err != tcpip.ErrWouldBlock { t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) @@ -312,3 +318,168 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) } } + +// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers +// itself bound to all addresses in the subnet of an assigned address and TCP +// traffic is sent/received correctly. +func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { + const ( + nicID = 1 + localPort = 80 + ) + + ipv4ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8 + ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) + ipv4Bytes[len(ipv4Bytes)-1]++ + otherIPv4Address := tcpip.Address(ipv4Bytes) + + ipv6ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + } + ipv6Bytes := []byte(ipv6Addr.Address) + ipv6Bytes[len(ipv6Bytes)-1]++ + otherIPv6Address := tcpip.Address(ipv6Bytes) + + tests := []struct { + name string + addAddress tcpip.ProtocolAddress + bindAddr tcpip.Address + dstAddr tcpip.Address + expectAccept bool + }{ + { + name: "IPv4 bind to wildcard and send to assigned address", + addAddress: ipv4ProtocolAddress, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + expectAccept: true, + }, + { + name: "IPv4 bind to wildcard and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + dstAddr: otherIPv4Address, + expectAccept: true, + }, + { + name: "IPv4 bind to wildcard send to other address", + addAddress: ipv4ProtocolAddress, + dstAddr: remoteIPv4Addr, + expectAccept: false, + }, + { + name: "IPv4 bind to other subnet-local address and send to assigned address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + expectAccept: false, + }, + { + name: "IPv4 bind and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: otherIPv4Address, + expectAccept: true, + }, + { + name: "IPv4 bind to assigned address and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + dstAddr: otherIPv4Address, + expectAccept: false, + }, + + { + name: "IPv6 bind and send to assigned address", + addAddress: ipv6ProtocolAddress, + bindAddr: ipv6Addr.Address, + dstAddr: ipv6Addr.Address, + expectAccept: true, + }, + { + name: "IPv6 bind to wildcard and send to other subnet-local address", + addAddress: ipv6ProtocolAddress, + dstAddr: otherIPv6Address, + expectAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) + } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer listeningEndpoint.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := listeningEndpoint.Bind(bindAddr); err != nil { + t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err) + } + + if err := listeningEndpoint.Listen(1); err != nil { + t.Fatalf("listeningEndpoint.Listen(1): %s", err) + } + + connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer connectingEndpoint.Close() + + connectAddr := tcpip.FullAddress{ + Addr: test.dstAddr, + Port: localPort, + } + if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + } + + if !test.expectAccept { + if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + } + return + } + + // Wait for the listening endpoint to be "readable". That is, wait for a + // new connection. + <-ch + var addr tcpip.FullAddress + if _, _, err := listeningEndpoint.Accept(&addr); err != nil { + t.Fatalf("listeningEndpoint.Accept(nil): %s", err) + } + if addr.Addr != test.addAddress.AddressWithPrefix.Address { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index f1028823b..2e59f6a42 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -73,7 +73,6 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetChecksum(^header.Checksum(pkt, 0)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(totalLen), Protocol: uint8(icmp.ProtocolNumber4), TTL: ttl, @@ -97,11 +96,11 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -244,7 +243,6 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(totalLen), Protocol: uint8(udp.ProtocolNumber), TTL: ttl, @@ -274,11 +272,11 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: uint16(payloadLen), + TransportProtocol: udp.ProtocolNumber, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -409,7 +407,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - wq := waiter.Queue{} + var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) @@ -447,8 +445,6 @@ func TestReuseAddrAndBroadcast(t *testing.T) { loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") ) - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) - tests := []struct { name string broadcastAddr tcpip.Address @@ -492,29 +488,30 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, }) + type endpointAndWaiter struct { + ep tcpip.Endpoint + ch chan struct{} + } + var eps []endpointAndWaiter // We create endpoints that bind to both the wildcard address and the // broadcast address to make sure both of these types of "broadcast // interested" endpoints receive broadcast packets. - wq := waiter.Queue{} - var eps []tcpip.Endpoint for _, bindWildcard := range []bool{false, true} { // Create multiple endpoints for each type of "broadcast interested" // endpoint so we can test that all endpoints receive the broadcast // packet. for i := 0; i < 2; i++ { + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) - } - - if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err) - } + ep.SocketOptions().SetReuseAddress(true) + ep.SocketOptions().SetBroadcast(true) bindAddr := tcpip.FullAddress{Port: localPort} if bindWildcard { @@ -528,7 +525,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { } } - eps = append(eps, ep) + eps = append(eps, endpointAndWaiter{ep: ep, ch: ch}) } } @@ -539,14 +536,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - if n, _, err := wep.Write(data, writeOpts); err != nil { + data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) + if n, _, err := wep.ep.Write(data, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) } for j, rep := range eps { - if gotPayload, _, err := rep.Read(nil); err != nil { + // Wait for the endpoint to become readable. + <-rep.ch + + if gotPayload, _, err := rep.ep.Read(nil); err != nil { t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go new file mode 100644 index 000000000..02fc47015 --- /dev/null +++ b/pkg/tcpip/tests/integration/route_test.go @@ -0,0 +1,388 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TestLocalPing tests pinging a remote that is local the stack. +// +// This tests that a local route is created and packets do not leave the stack. +func TestLocalPing(t *testing.T) { + const ( + nicID = 1 + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo + // request/reply packets. + icmpDataOffset = 8 + ) + + channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } + channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { + channelEP := e.(*channel.Endpoint) + if n := channelEP.Drain(); n != 0 { + t.Fatalf("got channelEP.Drain() = %d, want = 0", n) + } + } + + ipv4ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) + hdr.SetType(header.ICMPv4Echo) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + ipv6ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9} + hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) + hdr.SetType(header.ICMPv6EchoRequest) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + tests := []struct { + name string + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber + linkEndpoint func() stack.LinkEndpoint + localAddr tcpip.Address + icmpBuf func(*testing.T) buffer.View + expectedConnectErr *tcpip.Error + checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) + }{ + { + name: "IPv4 loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: ipv4Loopback, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: header.IPv6Loopback, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv4Addr.Address, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv6Addr.Address, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv4 loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() + + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } + + if test.expectedConnectErr != nil { + return + } + + payload := tcpip.SlicePayload(test.icmpBuf(t)) + var wOpts tcpip.WriteOptions + if n, _, err := ep.Write(payload, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } + + // Wait for the endpoint to become readable. + <-ch + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } +} + +// TestLocalUDP tests sending UDP packets between two endpoints that are local +// to the stack. +// +// This tests that that packets never leave the stack and the addresses +// used when sending a packet. +func TestLocalUDP(t *testing.T) { + const ( + nicID = 1 + ) + + tests := []struct { + name string + canBePrimaryAddr tcpip.ProtocolAddress + firstPrimaryAddr tcpip.ProtocolAddress + }{ + { + name: "IPv4", + canBePrimaryAddr: ipv4Addr1, + firstPrimaryAddr: ipv4Addr2, + }, + { + name: "IPv6", + canBePrimaryAddr: ipv6Addr1, + firstPrimaryAddr: ipv6Addr2, + }, + } + + subTests := []struct { + name string + addAddress bool + expectedWriteErr *tcpip.Error + }{ + { + name: "Unassigned local address", + addAddress: false, + expectedWriteErr: tcpip.ErrNoRoute, + }, + { + name: "Assigned local address", + addAddress: true, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + s := stack.New(stackOpts) + ep := channel.New(1, header.IPv6MinimumMTU, "") + + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if subTest.addAddress { + if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + } + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer server.Close() + + bindAddr := tcpip.FullAddress{Port: 80} + if err := server.Bind(bindAddr); err != nil { + t.Fatalf("server.Bind(%#v): %s", bindAddr, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventIn) + client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer client.Close() + + serverAddr := tcpip.FullAddress{ + Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, + Port: 80, + } + + clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &serverAddr, + } + if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) + } else if subTest.expectedWriteErr != nil { + // Nothing else to test if we expected not to be able to send the + // UDP packet. + return + } else if n != int64(len(clientPayload)) { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload)) + } + } + + // Wait for the server endpoint to become readable. + <-serverCH + + var clientAddr tcpip.FullAddress + if v, _, err := server.Read(&clientAddr); err != nil { + t.Fatalf("server.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" { + t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) + } + if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address { + t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address) + } + if t.Failed() { + t.FailNow() + } + } + + serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &clientAddr, + } + if n, _, err := server.Write(serverPayload, wOpts); err != nil { + t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) + } else if n != int64(len(serverPayload)) { + t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload)) + } + } + + // Wait for the client endpoint to become readable. + <-clientCH + + var gotServerAddr tcpip.FullAddress + if v, _, err := client.Read(&gotServerAddr); err != nil { + t.Fatalf("client.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" { + t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) + } + if gotServerAddr.Addr != serverAddr.Addr { + t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr) + } + if t.Failed() { + t.FailNow() + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 41eb0ca44..74fe19e98 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -49,6 +49,7 @@ const ( // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. @@ -71,18 +72,19 @@ type endpoint struct { // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // ops is used to get socket level options. + ops tcpip.SocketOptions } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return &endpoint{ + ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, @@ -93,7 +95,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), - }, nil + } + ep.ops.InitHandler(ep) + return ep, nil } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -126,7 +130,10 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } // Update the state. e.state = stateClosed @@ -139,6 +146,7 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } @@ -264,26 +272,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } } - var route *stack.Route - if to == nil { - route = &e.route - - if route.IsResolutionRequired() { - // Promote lock to exclusive if using a shared route, - // given that it may need to change in Route.Resolve() - // call below. - e.mu.RUnlock() - defer e.mu.RLock() - - e.mu.Lock() - defer e.mu.Unlock() - - // Recheck state after lock was re-acquired. - if e.state != stateConnected { - return 0, nil, tcpip.ErrInvalidEndpointState - } - } - } else { + route := e.route + if to != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC @@ -307,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } defer r.Release() - route = &r + route = r } if route.IsResolutionRequired() { @@ -340,26 +330,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.SocketDetachFilterOption: - return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - } - return nil -} - -// SetSockOptBool sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return nil } @@ -375,17 +351,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption: - return false, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -423,16 +388,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { @@ -755,7 +711,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -800,7 +756,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, }, } @@ -853,3 +809,8 @@ func (*endpoint) Wait() {} func (*endpoint) LastError() *tcpip.Error { return nil } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 87d510f96..3820e5dc7 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -101,7 +101,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 072601d2d..9faab4b9e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -60,6 +60,8 @@ type packet struct { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` @@ -83,12 +85,13 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` + + // ops is used to get socket level options. + ops tcpip.SocketOptions } // NewEndpoint returns a new packet endpoint. @@ -104,6 +107,7 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, } + ep.ops.InitHandler(ep) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -200,8 +204,8 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha } // Peek implements tcpip.Endpoint.Peek. -func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be @@ -300,26 +304,15 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - ep.mu.Lock() - ep.linger = *v - ep.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption -} - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { @@ -375,21 +368,7 @@ func (ep *endpoint) LastError() *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - ep.mu.Lock() - *o = ep.linger - ep.mu.Unlock() - return nil - - default: - return tcpip.ErrNotSupported - } -} - -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrNotSupported + return tcpip.ErrNotSupported } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -543,4 +522,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { return &ep.stats } +// SetOwner implements tcpip.Endpoint.SetOwner. func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { + return &ep.ops +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e37c00523..87c60bdab 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -58,12 +58,13 @@ type rawPacket struct { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool - hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -82,13 +83,14 @@ type endpoint struct { bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // ops is used to get socket level options. + ops tcpip.SocketOptions } // NewEndpoint returns a raw endpoint for the given protocols. @@ -111,8 +113,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, - hdrIncluded: !associated, } + e.ops.InitHandler(e) + e.ops.SetHeaderIncluded(!associated) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -167,9 +170,11 @@ func (e *endpoint) Close() { e.rcvList.Remove(e.rcvList.Front()) } - if e.connected { + e.connected = false + + if e.route != nil { e.route.Release() - e.connected = false + e.route = nil } e.closed = true @@ -220,6 +225,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrInvalidOptionValue } + if opts.To != nil { + // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { + return 0, nil, tcpip.ErrInvalidOptionValue + } + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -263,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if e.hdrIncluded { + if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -293,7 +305,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if e.route.IsResolutionRequired() { - savedRoute := &e.route + savedRoute := e.route // Promote lock to exclusive if using a shared route, // given that it may need to change in finishWrite. e.mu.RUnlock() @@ -301,7 +313,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Make sure that the route didn't change during the // time we didn't hold the lock. - if !e.connected || savedRoute != &e.route { + if !e.connected || savedRoute != e.route { e.mu.Unlock() return 0, nil, tcpip.ErrInvalidEndpointState } @@ -311,7 +323,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return n, ch, err } - n, ch, err := e.finishWrite(payloadBytes, &e.route) + n, ch, err := e.finishWrite(payloadBytes, e.route) e.mu.RUnlock() return n, ch, err } @@ -332,7 +344,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - n, ch, err := e.finishWrite(payloadBytes, &route) + n, ch, err := e.finishWrite(payloadBytes, route) route.Release() e.mu.RUnlock() return n, ch, err @@ -353,7 +365,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if e.hdrIncluded { + if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), }) @@ -379,8 +391,8 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Peek implements tcpip.Endpoint.Peek. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -390,6 +402,11 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { + return tcpip.ErrAddressFamilyNotSupported + } + e.mu.Lock() defer e.mu.Unlock() @@ -513,33 +530,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.IPHdrIncludedOption: - e.mu.Lock() - e.hdrIncluded = v - e.mu.Unlock() - return nil - } - return tcpip.ErrUnknownProtocolOption -} - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { @@ -586,33 +585,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption: - return false, nil - - case tcpip.IPHdrIncludedOption: - e.mu.Lock() - v := e.hdrIncluded - e.mu.Unlock() - return v, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -646,7 +619,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -671,14 +644,16 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { return } + remoteAddr := pkt.Network().SourceAddress() + if e.bound { // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != route.NICID() { + if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() return } @@ -686,7 +661,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // If connected, only accept packets from the remote address we // connected to. - if e.connected && e.route.RemoteAddress != route.RemoteAddress { + if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() return } @@ -696,8 +671,8 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // Push new packet into receive list and increment the buffer size. packet := &rawPacket{ senderAddr: tcpip.FullAddress{ - NIC: route.NICID(), - Addr: route.RemoteAddress, + NIC: pkt.NICID, + Addr: remoteAddr, }, } @@ -751,6 +726,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} +// LastError implements tcpip.Endpoint.LastError. func (*endpoint) LastError() *tcpip.Error { return nil } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 33bfb56cd..4a7e1c039 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -37,57 +37,63 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { } // beforeSave is invoked by stateify. -func (ep *endpoint) beforeSave() { +func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming // packets. - ep.rcvMu.Lock() + e.rcvMu.Lock() } // saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 + e.rcvBufSizeMax = 0 // Release the lock acquired in beforeSave() so regular endpoint closing // logic can proceed after save. - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return max } // loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max } // afterLoad is invoked by stateify. -func (ep *endpoint) afterLoad() { - stack.StackFromEnv.RegisterRestoredEndpoint(ep) +func (e *endpoint) afterLoad() { + stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. -func (ep *endpoint) Resume(s *stack.Stack) { - ep.stack = s +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s // If the endpoint is connected, re-connect. - if ep.connected { + if e.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) + // TODO(gvisor.dev/issue/4906): Properly restore the route with the right + // remote address. We used to pass e.remote.RemoteAddress which was + // effectively the empty address but since moving e.route to hold a pointer + // to a route instead of the route by value, we pass the empty address + // directly. Obviously this was always wrong since we should provide the + // remote address we were connected to, to properly restore the route. + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false) if err != nil { panic(err) } } // If the endpoint is bound, re-bind. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { + if e.bound { + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if ep.associated { - if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + if e.associated { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 518449602..cf232b508 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "more_shards") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -45,7 +45,9 @@ go_library( "rcv.go", "rcv_state.go", "reno.go", + "reno_recovery.go", "sack.go", + "sack_recovery.go", "sack_scoreboard.go", "segment.go", "segment_heap.go", @@ -91,7 +93,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], - shard_count = 10, + shard_count = more_shards, deps = [ ":tcp", "//pkg/rand", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b706438bd..3e1041cbe 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { - netProto = s.route.NetProto + netProto = s.netProto } + + route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(s.remoteLinkAddr) + n := newEndpoint(l.stack, netProto, queue) - n.v6only = l.v6Only + n.ops.SetV6Only(l.v6Only) n.ID = s.id - n.boundNICID = s.route.NICID() - n.route = s.route.Clone() - n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.boundNICID = s.nicID + n.route = route + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} n.rcvBufSize = int(l.rcvWnd) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -225,18 +232,25 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - return n + return n, nil } -// createEndpointAndPerformHandshake creates a new endpoint in connected state -// and then performs the TCP 3-way handshake. +// startHandshake creates a new endpoint in connecting state and then sends +// the SYN-ACK for the TCP 3-way handshake. It returns the state of the +// handshake in progress, which includes the new endpoint in the SYN-RCVD +// state. // -// The new endpoint is returned with e.mu held. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { +// On success, a handshake h is returned with h.ep.mu held. +// +// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. +func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + if err != nil { + return nil, err + } // Lock the endpoint before registering to ensure that no out of // band changes are possible due to incoming packets etc till @@ -247,10 +261,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) if l.listenEP != nil { - l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { - l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() @@ -268,16 +280,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - l.listenEP.mu.Unlock() - } + l.removePendingEndpoint(ep) return nil, tcpip.ErrConnectionAborted } deferAccept = l.listenEP.deferAccept - l.listenEP.mu.Unlock() } // Register new endpoint so that packets are routed to it. @@ -296,28 +304,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.isRegistered = true - // Perform the 3-way handshake. - h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) - if err := h.execute(); err != nil { - ep.mu.Unlock() - ep.Close() - ep.notifyAborted() - - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - - ep.drainClosingSegmentQueue() - + // Initialize and start the handshake. + h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) + if err := h.start(); err != nil { + l.cleanupFailedHandshake(h) return nil, err } - ep.isConnectNotified = true + return h, nil +} - // Update the receive window scaling. We can't do it before the - // handshake because it's possible that the peer doesn't support window - // scaling. - ep.rcv.rcvWndScale = h.effectiveRcvWndScale() +// performHandshake performs a TCP 3-way handshake. On success, the new +// established endpoint is returned with e.mu held. +// +// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. +func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { + h, err := l.startHandshake(s, opts, queue, owner) + if err != nil { + return nil, err + } + ep := h.ep + if err := h.complete(); err != nil { + ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() + ep.stats.FailedConnectionAttempts.Increment() + l.cleanupFailedHandshake(h) + return nil, err + } + l.cleanupCompletedHandshake(h) return ep, nil } @@ -344,6 +357,39 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } +// Precondition: h.ep.mu must be held. +func (l *listenContext) cleanupFailedHandshake(h *handshake) { + e := h.ep + e.mu.Unlock() + e.Close() + e.notifyAborted() + if l.listenEP != nil { + l.removePendingEndpoint(e) + } + e.drainClosingSegmentQueue() + e.h = nil +} + +// cleanupCompletedHandshake transfers any state from the completed handshake to +// the new endpoint. +// +// Precondition: h.ep.mu must be held. +func (l *listenContext) cleanupCompletedHandshake(h *handshake) { + e := h.ep + if l.listenEP != nil { + l.removePendingEndpoint(e) + } + e.isConnectNotified = true + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() + + // Clean up handshake state stored in the endpoint so that it can be GCed. + e.h = nil +} + // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. @@ -423,26 +469,40 @@ func (e *endpoint) notifyAborted() { // // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer ctx.synRcvdCount.dec() - defer func() { - e.mu.Lock() - e.decSynRcvdCount() - e.mu.Unlock() - }() +// +// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) + h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + e.synRcvdCount-- + return err } - ctx.removePendingEndpoint(n) - n.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(n) + go func() { + defer ctx.synRcvdCount.dec() + if err := h.complete(); err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + ctx.cleanupFailedHandshake(h) + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() + return + } + ctx.cleanupCompletedHandshake(h) + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() + h.ep.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(h.ep) + }() // S/R-SAFE: synRcvdCount is the barrier. + + return nil } func (e *endpoint) incSynRcvdCount() bool { @@ -455,10 +515,6 @@ func (e *endpoint) incSynRcvdCount() bool { return canInc } -func (e *endpoint) decSynRcvdCount() { - e.synRcvdCount-- -} - func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) @@ -468,7 +524,9 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { +// +// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() @@ -478,8 +536,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } switch { @@ -492,14 +549,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // backlog. if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() - go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. - return + _ = e.handleSynSegment(ctx, s, &opts) + return nil } ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } else { // If cookies are in use but the endpoint accept queue // is full then drop the syn. @@ -507,10 +564,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Send SYN without window scaling because we currently // don't encode this information in the cookie. // @@ -524,9 +588,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TS: opts.TS, TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, s.route), + MSS: calculateAdvertisedMSS(e.userMSS, route), } - e.sendSynTCP(&s.route, tcpFields{ + fields := tcpFields{ id: s.id, ttl: e.ttl, tos: e.sendTOS, @@ -534,8 +598,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { seq: cookie, ack: s.sequenceNumber + 1, rcvWnd: ctx.rcvWnd, - }, synOpts) + } + if err := e.sendSynTCP(route, fields, synOpts); err != nil { + return err + } e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil } case (s.flags & header.TCPFlagAck) != 0: @@ -548,7 +616,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } if !ctx.synRcvdCount.synCookiesInUse() { @@ -567,8 +635,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } iss := s.ackNumber - 1 @@ -588,7 +655,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. @@ -609,7 +676,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + if err != nil { + return err + } n.mu.Lock() @@ -623,7 +693,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return nil } // Register new endpoint so that packets are routed to it. @@ -633,7 +703,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return err } n.isRegistered = true @@ -671,14 +741,18 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) + return nil + + default: + return nil } } // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. -func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() - v6Only := e.v6only + v6Only := e.ops.GetV6Only() ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) defer func() { @@ -687,7 +761,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // to the endpoint. e.setEndpointState(StateClose) - // close any endpoints in SYN-RCVD state. + // Close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() // Do cleanup if needed. @@ -715,12 +789,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { - return nil + return } if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { s := e.segmentQueue.dequeue() - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } close(e.drainDone) @@ -739,7 +815,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { break } - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 0aaef495d..c944dccc0 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -16,6 +16,7 @@ package tcp import ( "encoding/binary" + "math" "time" "gvisor.dev/gvisor/pkg/rand" @@ -102,21 +103,26 @@ type handshake struct { // been received. This is required to stop retransmitting the // original SYN-ACK when deferAccept is enabled. acked bool + + // sendSYNOpts is the cached values for the SYN options to be sent. + sendSYNOpts header.TCPSynOptions } -func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { - h := handshake{ - ep: ep, +func (e *endpoint) newHandshake() *handshake { + h := &handshake{ + ep: e, active: true, - rcvWnd: rcvWnd, - rcvWndScale: ep.rcvWndScaleForHandshake(), + rcvWnd: seqnum.Size(e.initialReceiveWindow()), + rcvWndScale: e.rcvWndScaleForHandshake(), } h.resetState() + // Store reference to handshake state in endpoint. + e.h = h return h } -func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { - h := newHandshake(ep, rcvWnd) +func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake { + h := e.newHandshake() h.resetToSynRcvd(isn, irs, opts, deferAccept) return h } @@ -128,7 +134,7 @@ func FindWndScale(wnd seqnum.Size) int { return 0 } - max := seqnum.Size(0xffff) + max := seqnum.Size(math.MaxUint16) s := 0 for wnd > max && s < header.MaxWndScale { s++ @@ -293,9 +299,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { MSS: amss, } if ttl == 0 { - ttl = s.route.DefaultTTL() + ttl = h.ep.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: ttl, tos: h.ep.sendTOS, @@ -356,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -491,17 +497,20 @@ func (h *handshake) resolveRoute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.LastError() + return h.ep.lastErrorLocked() } } // Wait for notification. - index, _ = s.Fetch(true) + h.ep.mu.Unlock() + index, _ = s.Fetch(true /* block */) + h.ep.mu.Lock() } } -// execute executes the TCP 3-way handshake. -func (h *handshake) execute() *tcpip.Error { +// start resolves the route if necessary and sends the first +// SYN/SYN-ACK. +func (h *handshake) start() *tcpip.Error { if h.ep.route.IsResolutionRequired() { if err := h.resolveRoute(); err != nil { return err @@ -509,19 +518,7 @@ func (h *handshake) execute() *tcpip.Error { } h.startTime = time.Now() - // Initialize the resend timer. - resendWaker := sleep.Waker{} - timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, resendWaker.Assert) - defer rt.Stop() - - // Set up the wakers. - s := sleep.Sleeper{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) - defer s.Done() - + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled @@ -529,10 +526,6 @@ func (h *handshake) execute() *tcpip.Error { sackEnabled = false } - // Send the initial SYN segment and loop until the handshake is - // completed. - h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) - synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, @@ -542,9 +535,8 @@ func (h *handshake) execute() *tcpip.Error { MSS: h.ep.amss, } - // Execute is also called in a listen context so we want to make sure we - // only send the TS/SACK option when we received the TS/SACK in the - // initial SYN. + // start() is also called in a listen context so we want to make sure we only + // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { synOpts.TS = h.ep.sendTSOk synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) @@ -555,7 +547,8 @@ func (h *handshake) execute() *tcpip.Error { } } - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.sendSYNOpts = synOpts + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -564,19 +557,37 @@ func (h *handshake) execute() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) + return nil +} + +// complete completes the TCP 3-way handshake initiated by h.start(). +func (h *handshake) complete() *tcpip.Error { + // Set up the wakers. + s := sleep.Sleeper{} + resendWaker := sleep.Waker{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + // Initialize the resend timer. + timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + if err != nil { + return err + } + defer timer.stop() for h.state != handshakeCompleted { + // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held + // throughout handshake processing). h.ep.mu.Unlock() - index, _ := s.Fetch(true) + index, _ := s.Fetch(true /* block */) h.ep.mu.Lock() switch index { case wakerForResend: - timeOut *= 2 - if timeOut > MaxRTO { - return tcpip.ErrTimeout + if err := timer.reset(); err != nil { + return err } - rt.Reset(timeOut) // Resend the SYN/SYN-ACK only if the following conditions hold. // - It's an active handshake (deferAccept does not apply) // - It's a passive handshake and we have not yet got the final-ACK. @@ -586,7 +597,7 @@ func (h *handshake) execute() *tcpip.Error { // the connection with another ACK or data (as ACKs are never // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -594,7 +605,7 @@ func (h *handshake) execute() *tcpip.Error { seq: h.iss, ack: h.ackNum, rcvWnd: h.rcvWnd, - }, synOpts) + }, h.sendSYNOpts) } case wakerForNotification: @@ -620,9 +631,8 @@ func (h *handshake) execute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.LastError() + return h.ep.lastErrorLocked() } - case wakerForNewSegment: if err := h.processSegments(); err != nil { return err @@ -633,6 +643,34 @@ func (h *handshake) execute() *tcpip.Error { return nil } +type backoffTimer struct { + timeout time.Duration + maxTimeout time.Duration + t *time.Timer +} + +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { + if timeout > maxTimeout { + return nil, tcpip.ErrTimeout + } + bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} + bt.t = time.AfterFunc(timeout, f) + return bt, nil +} + +func (bt *backoffTimer) reset() *tcpip.Error { + bt.timeout *= 2 + if bt.timeout > MaxRTO { + return tcpip.ErrTimeout + } + bt.t.Reset(bt.timeout) + return nil +} + +func (bt *backoffTimer) stop() { + bt.t.Stop() +} + func parseSynSegmentOptions(s *segment) header.TCPSynOptions { synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) if synOpts.TS { @@ -767,7 +805,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // TCP header, then the kernel calculate a checksum of the // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) - } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + } else if r.RequiresTXTransportChecksum() { xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } @@ -781,8 +819,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso data = data.Clone(nil) optLen := len(tf.opts) - if tf.rcvWnd > 0xffff { - tf.rcvWnd = 0xffff + if tf.rcvWnd > math.MaxUint16 { + tf.rcvWnd = math.MaxUint16 } mss := int(gso.MSS) @@ -826,8 +864,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // network endpoint and under the provided identity. func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) - if tf.rcvWnd > 0xffff { - tf.rcvWnd = 0xffff + if tf.rcvWnd > math.MaxUint16 { + tf.rcvWnd = math.MaxUint16 } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { @@ -902,7 +940,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := e.sendTCP(&e.route, tcpFields{ + err := e.sendTCP(e.route, tcpFields{ id: e.ID, ttl: e.ttl, tos: e.sendTOS, @@ -963,7 +1001,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.setEndpointState(StateError) - e.HardError = err + e.hardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk @@ -1040,13 +1078,13 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) - if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) + if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. - ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) + ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) } if ep == nil { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) s.decRef() return } @@ -1102,7 +1140,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // delete the TCB, and return. case StateCloseWait: e.transitionToStateCloseLocked() - e.HardError = tcpip.ErrAborted + e.hardError = tcpip.ErrAborted e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: @@ -1247,7 +1285,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { userTimeout := e.userTimeout e.keepalive.Lock() - if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { + if !e.SocketOptions().GetKeepAlive() || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } @@ -1284,7 +1322,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } // Start the keepalive timer IFF it's enabled and there is no pending // data to send. - if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() e.keepalive.Unlock() return @@ -1314,7 +1352,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ epilogue := func() { // e.mu is expected to be hold upon entering this section. - if e.snd != nil { e.snd.resendTimer.cleanup() } @@ -1338,20 +1375,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if handshake { - // This is an active connection, so we must initiate the 3-way - // handshake, and then inform potential waiters about its - // completion. - initialRcvWnd := e.initialReceiveWindow() - h := newHandshake(e, seqnum.Size(initialRcvWnd)) - h.ep.setEndpointState(StateSynSent) - - if err := h.execute(); err != nil { + if err := e.h.complete(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() e.setEndpointState(StateError) - e.HardError = err + e.hardError = err e.workerCleanup = true // Lock released below. @@ -1360,13 +1390,12 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } - e.keepalive.timer.init(&e.keepalive.waker) - defer e.keepalive.timer.cleanup() - drained := e.drainDone != nil if drained { close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } // Set up the functions that will be called when the main protocol loop @@ -1445,7 +1474,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } } @@ -1535,7 +1563,7 @@ loop: } e.mu.Unlock() - v, _ := s.Fetch(true) + v, _ := s.Fetch(true /* block */) e.mu.Lock() // We need to double check here because the notification may be @@ -1608,7 +1636,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() } extTW, newSyn := e.rcv.handleTimeWaitSegment(s) if newSyn { - info := e.EndpointInfo.TransportEndpointInfo + info := e.TransportEndpointInfo newID := info.ID newID.RemoteAddress = "" newID.RemotePort = 0 @@ -1620,7 +1648,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} } for _, netProto := range netProtos { - if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil { tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { @@ -1683,7 +1711,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { for { e.mu.Unlock() - v, _ := s.Fetch(true) + v, _ := s.Fetch(true /* block */) e.mu.Lock() switch v { case newSegment: diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 98aecab9e..21162f01a 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -172,10 +172,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newSegment(r, id, pkt) - if !s.parse() { + + s := newIncomingSegment(id, pkt) + if !s.parse(pkt.RXTransportChecksumValidated) { ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 560b4904c..1d1b01a6c 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -236,6 +236,25 @@ func TestV6ConnectWhenBoundToWildcard(t *testing.T) { testV6Connect(t, c) } +func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { + c := context.NewWithOpts(t, context.Options{ + EnableV6: true, + MTU: defaultMTU, + }) + defer c.Cleanup() + + // Create a v6 endpoint but don't set the v6-only TCP option. + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -386,14 +405,6 @@ func testV4Accept(t *testing.T, c *context.Context) { } } - // Make sure we get the same error when calling the original ep and the - // new one. This validates that v4-mapped endpoints are still able to - // query the V6Only flag, whereas pure v4 endpoints are not. - _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) - if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { - t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) - } - // Check the peer address. addr, err := nep.GetRemoteAddress() if err != nil { @@ -511,12 +522,12 @@ func TestV6AcceptOnV6(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress - nep, _, err := c.EP.Accept(&addr) + _, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept(&addr) + _, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -529,12 +540,6 @@ func TestV6AcceptOnV6(t *testing.T) { if addr.Addr != context.TestV6Addr { t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) } - - // Make sure we can still query the v6 only status of the new endpoint, - // that is, that it is in fact a v6 socket. - if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) - } } func TestV4AcceptOnV4(t *testing.T) { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3bcd3923a..7a37c10bb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -310,16 +310,12 @@ type Stats struct { func (*Stats) IsEndpointStats() {} // EndpointInfo holds useful information about a transport endpoint which -// can be queried by monitoring tools. +// can be queried by monitoring tools. This exists to allow tcp-only state to +// be exposed. // // +stateify savable type EndpointInfo struct { stack.TransportEndpointInfo - - // HardError is meaningful only when state is stateError. It stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. HardError is protected by endpoint mu. - HardError *tcpip.Error `state:".(string)"` } // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -367,6 +363,7 @@ func (*EndpointInfo) IsEndpointInfo() {} // +stateify savable type endpoint struct { EndpointInfo + tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the // a given tcp processor goroutine. @@ -386,6 +383,11 @@ type endpoint struct { waiterQueue *waiter.Queue `state:"wait"` uniqueID uint64 + // hardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. hardError is protected by endpoint mu. + hardError *tcpip.Error `state:".(string)"` + // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. lastErrorMu sync.Mutex `state:"nosave"` @@ -421,7 +423,10 @@ type endpoint struct { // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. - mu sync.Mutex `state:"nosave"` + // + // During handshake, mu is locked by the protocol listen goroutine and + // released by the handshake completion goroutine. + mu sync.CrossGoroutineMutex `state:"nosave"` ownedByUser uint32 // state must be read/set using the EndpointState()/setEndpointState() @@ -436,13 +441,14 @@ type endpoint struct { isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` boundNICID tcpip.NICID - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` ttl uint8 - v6only bool isConnectNotified bool - // TCP should never broadcast but Linux nevertheless supports enabling/ - // disabling SO_BROADCAST, albeit as a NOOP. - broadcast bool + + // h stores a reference to the current handshake state if the endpoint is in + // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep. + // nil otherwise. + h *handshake `state:"nosave"` // portFlags stores the current values of port related flags. portFlags ports.Flags @@ -504,24 +510,9 @@ type endpoint struct { // delay is a boolean (0 is false) and must be accessed atomically. delay uint32 - // cork holds back segments until full. - // - // cork is a boolean (0 is false) and must be accessed atomically. - cork uint32 - // scoreboard holds TCP SACK Scoreboard information for this endpoint. scoreboard *SACKScoreboard - // The options below aren't implemented, but we remember the user - // settings because applications expect to be able to set/query these - // options. - - // slowAck holds the negated state of quick ack. It is stubbed out and - // does nothing. - // - // slowAck is a boolean (0 is false) and must be accessed atomically. - slowAck uint32 - // segmentQueue is used to hand received segments to the protocol // goroutine. Segments are queued as long as the queue is not full, // and dropped when it is. @@ -683,8 +674,8 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption + // ops is used to get socket level options. + ops tcpip.SocketOptions } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -696,7 +687,7 @@ func (e *endpoint) UniqueID() uint64 { // // If userMSS is non-zero and is not greater than the maximum possible MSS for // r, it will be used; otherwise, the maximum possible MSS will be used. -func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { +func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 { // The maximum possible MSS is dependent on the route. // TODO(b/143359391): Respect TCP Min and Max size. maxMSS := uint16(r.MTU() - header.TCPMinimumSize) @@ -721,9 +712,9 @@ func (e *endpoint) LockUser() { for { // Try first if the sock is locked then check if it's owned // by another user goroutine if not then we spin, otherwise - // we just goto sleep on the Lock() and wait. + // we just go to sleep on the Lock() and wait. if !e.mu.TryLock() { - // If socket is owned by the user then just goto sleep + // If socket is owned by the user then just go to sleep // as the lock could be held for a reasonably long time. if atomic.LoadUint32(&e.ownedByUser) == 1 { e.mu.Lock() @@ -845,7 +836,6 @@ func (e *endpoint) recentTimestamp() uint32 { // +stateify savable type keepalive struct { sync.Mutex `state:"nosave"` - enabled bool idle time.Duration interval time.Duration count int @@ -879,6 +869,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } + e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) + e.ops.SetQuickAck(true) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -902,7 +895,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var de tcpip.TCPDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { - e.SetSockOptBool(tcpip.DelayOption, true) + e.ops.SetDelayOption(true) } var tcpLT tcpip.TCPLingerTimeoutOption @@ -922,6 +915,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) + e.keepalive.timer.init(&e.keepalive.waker) return e } @@ -1043,7 +1037,8 @@ func (e *endpoint) Close() { return } - if e.linger.Enabled && e.linger.Timeout == 0 { + linger := e.SocketOptions().GetLinger() + if linger.Enabled && linger.Timeout == 0 { s := e.EndpointState() isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv if isResetState { @@ -1069,9 +1064,7 @@ func (e *endpoint) Close() { e.closeNoShutdownLocked() } -// closeNoShutdown closes the endpoint without doing a full shutdown. This is -// used when a connection needs to be aborted with a RST and we want to skip -// a full 4 way TCP shutdown. +// closeNoShutdown closes the endpoint without doing a full shutdown. func (e *endpoint) closeNoShutdownLocked() { // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also @@ -1098,6 +1091,7 @@ func (e *endpoint) closeNoShutdownLocked() { return } + eventMask := waiter.EventIn | waiter.EventOut // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. if e.workerRunning { @@ -1109,8 +1103,12 @@ func (e *endpoint) closeNoShutdownLocked() { } else { e.transitionToStateCloseLocked() // Notify that the endpoint is closed. - e.waiterQueue.Notify(waiter.EventHUp) + eventMask |= waiter.EventHUp } + + // The TCP closing state-machine would eventually notify EventHUp, but we + // notify EventIn|EventOut immediately to unblock any blocked waiters. + e.waiterQueue.Notify(eventMask) } // closePendingAcceptableConnections closes all connections that have completed @@ -1143,6 +1141,7 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. e.closePendingAcceptableConnectionsLocked() + e.keepalive.timer.cleanup() e.workerCleanup = false @@ -1159,7 +1158,11 @@ func (e *endpoint) cleanupLocked() { e.boundPortFlags = ports.Flags{} e.boundDest = tcpip.FullAddress{} - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -1269,11 +1272,20 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -func (e *endpoint) LastError() *tcpip.Error { +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) hardErrorLocked() *tcpip.Error { + err := e.hardError + e.hardError = nil + return err +} + +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) lastErrorLocked() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1281,6 +1293,16 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// LastError implements tcpip.Endpoint.LastError. +func (e *endpoint) LastError() *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + if err := e.hardErrorLocked(); err != nil { + return err + } + return e.lastErrorLocked() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1302,9 +1324,11 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, bufUsed := e.rcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() - he := e.HardError if s == StateError { - return buffer.View{}, tcpip.ControlMessages{}, he + if err := e.hardErrorLocked(); err != nil { + return buffer.View{}, tcpip.ControlMessages{}, err + } + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } e.stats.ReadErrors.NotConnected.Increment() return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected @@ -1360,9 +1384,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { + // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: - return 0, e.HardError + if err := e.hardErrorLocked(); err != nil { + return 0, err + } + return 0, tcpip.ErrClosedForSend case !s.connecting() && !s.connected(): return 0, tcpip.ErrClosedForSend case s.connecting(): @@ -1425,7 +1453,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) + s := newOutgoingSegment(e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) @@ -1468,7 +1496,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1476,10 +1504,10 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.HardError + return 0, e.hardErrorLocked() } e.stats.ReadErrors.InvalidEndpointState.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + return 0, tcpip.ErrInvalidEndpointState } e.rcvListMu.Lock() @@ -1488,9 +1516,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + return 0, tcpip.ErrClosedForReceive } - return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + return 0, tcpip.ErrWouldBlock } // Make a copy of vec so we can modify the slide headers. @@ -1505,7 +1533,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro for len(v) > 0 { if len(vec) == 0 { - return num, tcpip.ControlMessages{}, nil + return num, nil } if len(vec[0]) == 0 { vec = vec[1:] @@ -1520,7 +1548,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } } - return num, tcpip.ControlMessages{}, nil + return num, nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1592,77 +1620,39 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo return false, false } -// SetSockOptBool sets a socket option. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - - case tcpip.BroadcastOption: - e.LockUser() - e.broadcast = v - e.UnlockUser() - - case tcpip.CorkOption: - e.LockUser() - if !v { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - e.UnlockUser() - - case tcpip.DelayOption: - if v { - atomic.StoreUint32(&e.delay, 1) - } else { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - - case tcpip.QuickAckOption: - o := uint32(1) - if v { - o = 0 - } - atomic.StoreUint32(&e.slowAck, o) - - case tcpip.ReuseAddressOption: - e.LockUser() - e.portFlags.TupleOnly = v - e.UnlockUser() - - case tcpip.ReusePortOption: - e.LockUser() - e.portFlags.LoadBalanced = v - e.UnlockUser() +// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +func (e *endpoint) OnReuseAddressSet(v bool) { + e.LockUser() + e.portFlags.TupleOnly = v + e.UnlockUser() +} - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } +// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +func (e *endpoint) OnReusePortSet(v bool) { + e.LockUser() + e.portFlags.LoadBalanced = v + e.UnlockUser() +} - // We only allow this to be set when we're in the initial state. - if e.EndpointState() != StateInitial { - return tcpip.ErrInvalidEndpointState - } +// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet. +func (e *endpoint) OnKeepAliveSet(v bool) { + e.notifyProtocolGoroutine(notifyKeepaliveChanged) +} - e.LockUser() - e.v6only = v - e.UnlockUser() +// OnDelayOptionSet implements tcpip.SocketOptionsHandler.OnDelayOptionSet. +func (e *endpoint) OnDelayOptionSet(v bool) { + if !v { + // Handle delayed data. + e.sndWaker.Assert() } +} - return nil +// OnCorkOptionSet implements tcpip.SocketOptionsHandler.OnCorkOptionSet. +func (e *endpoint) OnCorkOptionSet(v bool) { + if !v { + // Handle the corked data. + e.sndWaker.Assert() + } } // SetSockOptInt sets a socket option. @@ -1846,9 +1836,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - case *tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(*v) @@ -1917,11 +1904,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.LockUser() - e.linger = *v - e.UnlockUser() - default: return nil } @@ -1944,66 +1926,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.BroadcastOption: - e.LockUser() - v := e.broadcast - e.UnlockUser() - return v, nil - - case tcpip.CorkOption: - return atomic.LoadUint32(&e.cork) != 0, nil - - case tcpip.DelayOption: - return atomic.LoadUint32(&e.delay) != 0, nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - return v, nil - - case tcpip.QuickAckOption: - v := atomic.LoadUint32(&e.slowAck) == 0 - return v, nil - - case tcpip.ReuseAddressOption: - e.LockUser() - v := e.portFlags.TupleOnly - e.UnlockUser() - - return v, nil - - case tcpip.ReusePortOption: - e.LockUser() - v := e.portFlags.LoadBalanced - e.UnlockUser() - - return v, nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.LockUser() - v := e.v6only - e.UnlockUser() - - return v, nil - - case tcpip.MulticastLoopOption: - return true, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -2114,10 +2036,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - *o = 1 - case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc @@ -2146,11 +2064,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { Port: port, } - case *tcpip.LingerOption: - e.LockUser() - *o = e.linger - e.UnlockUser() - default: return tcpip.ErrUnknownProtocolOption } @@ -2160,7 +2073,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -2176,6 +2089,8 @@ func (*endpoint) Disconnect() *tcpip.Error { func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { err := e.connect(addr, true, true) if err != nil && !err.IgnoreStats() { + // Connect failed. Let's wake up any waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() } @@ -2235,7 +2150,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnecting case StateError: - return e.HardError + if err := e.hardErrorLocked(); err != nil { + return err + } + return tcpip.ErrConnectionAborted default: return tcpip.ErrInvalidEndpointState @@ -2310,7 +2228,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // done yet) or the reservation was freed between the check above and // the FindTransportEndpoint below. But rather than retry the same port // we just skip it and move on. - transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID()) if transEP == nil { // ReservePort failed but there is no registered endpoint with // demuxer. Which indicates there is at least some endpoint that has @@ -2379,7 +2297,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { s.id = e.ID - s.route = r.Clone() e.sndWaker.Assert() } } @@ -2389,14 +2306,70 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if run { - e.workerRunning = true - e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + if err := e.startMainLoop(handshake); err != nil { + return err + } } return tcpip.ErrConnectStarted } +// startMainLoop sends the initial SYN and starts the main loop for the +// endpoint. +func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { + preloop := func() *tcpip.Error { + if handshake { + h := e.newHandshake() + e.setEndpointState(StateSynSent) + if err := h.start(); err != nil { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + e.setEndpointState(StateError) + e.hardError = err + + // Call cleanupLocked to free up any reservations. + e.cleanupLocked() + return err + } + } + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() + return nil + } + + if e.route.IsResolutionRequired() { + // If the endpoint is closed between releasing e.mu and the goroutine below + // acquiring it, make sure that cleanup is deferred to the new goroutine. + e.workerRunning = true + + // Sending the initial SYN may block due to route resolution; do it in a + // separate goroutine to avoid blocking the syscall goroutine. + go func() { // S/R-SAFE: will be drained before save. + e.mu.Lock() + if err := preloop(); err != nil { + e.workerRunning = false + e.mu.Unlock() + return + } + e.mu.Unlock() + _ = e.protocolMainLoop(handshake, nil) + }() + return nil + } + + // No route resolution is required, so we can send the initial SYN here without + // blocking. This will hopefully reduce overall latency by overlapping time + // spent waiting for a SYN-ACK and time spent spinning up a new goroutine + // for the main loop. + if err := preloop(); err != nil { + return err + } + e.workerRunning = true + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + return nil +} + // ConnectEndpoint is not supported. func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { return tcpip.ErrInvalidEndpointState @@ -2445,7 +2418,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.ID, nil) + s := newOutgoingSegment(e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ // Mark endpoint as closed. @@ -2627,14 +2600,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, + + // Expand netProtos to include v4 and v6 under dual-stack if the caller is + // binding to a wildcard (empty) address, and this is an IPv6 endpoint with + // v6only set to false. + if netProto == header.IPv6ProtocolNumber { + stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) + alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4 + if alsoBindToV4 { + netProtos = append(netProtos, header.IPv4ProtocolNumber) } } @@ -2715,9 +2690,9 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { } } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first - // land at the Dispatcher which then can either delivery using the + // land at the Dispatcher which then can either deliver using the // worker go routine or directly do the invoke the tcp processing inline // based on the state of the endpoint. } @@ -3051,6 +3026,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { FACK: rc.fack, RTT: rc.rtt, Reord: rc.reorderSeen, + DSACKSeen: rc.dsackSeen, } return s } @@ -3074,9 +3050,9 @@ func (e *endpoint) initHardwareGSO() { } func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.route.HasHardwareGSOCapability() { e.initHardwareGSO() - } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + } else if e.route.HasSoftwareGSOCapability() { e.gso = &stack.GSO{ MaxSize: e.route.GSOMaxSize(), Type: stack.GSOSW, @@ -3095,7 +3071,7 @@ func (e *endpoint) State() uint32 { func (e *endpoint) Info() tcpip.EndpointInfo { e.LockUser() // Make a copy of the endpoint info. - ret := e.EndpointInfo + ret := e.TransportEndpointInfo e.UnlockUser() return &ret } @@ -3120,3 +3096,8 @@ func (e *endpoint) Wait() { <-notifyCh } } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b25431467..ba67176b5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() { switch { case epState == StateInitial || epState == StateBound: case epState.connected() || epState.handshake(): - if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { - if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { + if !e.route.HasSaveRestoreCapability() { + if !e.route.HasDisconncetOkCapability() { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) @@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() { // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) + e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -320,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { } // saveHardError is invoked by stateify. -func (e *EndpointInfo) saveHardError() string { - if e.HardError == nil { +func (e *endpoint) saveHardError() string { + if e.hardError == nil { return "" } - return e.HardError.String() + return e.hardError.String() } // loadHardError is invoked by stateify. -func (e *EndpointInfo) loadHardError(s string) { +func (e *endpoint) loadHardError(s string) { if s == "" { return } - e.HardError = tcpip.StringToError(s) + e.hardError = tcpip.StringToError(s) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 070b634b4..596178625 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -30,6 +30,8 @@ import ( // The canonical way of using it is to pass the Forwarder.HandlePacket function // to stack.SetTransportProtocolHandler. type Forwarder struct { + stack *stack.Stack + maxInFlight int handler func(*ForwarderRequest) @@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward rcvWnd = DefaultReceiveBufferSize } return &Forwarder{ + stack: s, maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), @@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newSegment(r, id, pkt) +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + s := newIncomingSegment(id, pkt) defer s.decRef() // We only care about well-formed SYN packets. - if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn { return false } @@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) { delete(r.forwarder.inFlight, r.segment.id) r.forwarder.mu.Unlock() - // If the caller requested, send a reset. if sendReset { - replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) + replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */) } // Release all resources. @@ -150,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } f := r.forwarder - ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ + ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{ MSS: r.synOptions.MSS, WS: r.synOptions.WS, TS: r.synOptions.TS, diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 5bce73605..672159eed 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(r, ep, id, pkt) +func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + p.dispatcher.queuePacket(ep, id, pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." - -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newSegment(r, id, pkt) +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + s := newIncomingSegment(id, pkt) defer s.decRef() - if !s.parse() || !s.csumValid { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } if !s.flagIsSet(header.TCPFlagRst) { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(p.stack, s, stack.DefaultTOS, 0) } return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. -func replyWithReset(s *segment, tos, ttl uint8) { +// +// If the passed TTL is 0, then the route's default TTL will be used. +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { + route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) @@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) { flags |= header.TCPFlagAck ack = s.sequenceNumber.Add(s.logicalLen()) } - sendTCP(&s.route, tcpFields{ + + if ttl == 0 { + ttl = route.DefaultTTL() + } + + return sendTCP(route, tcpFields{ id: s.id, ttl: ttl, tos: tos, diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index d312b1b8b..e0a50a919 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -29,12 +29,12 @@ import ( // // +stateify savable type rackControl struct { + // dsackSeen indicates if the connection has seen a DSACK. + dsackSeen bool + // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value - // dsack indicates if the connection has seen a DSACK. - dsack bool - // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value @@ -122,3 +122,8 @@ func (rc *rackControl) detectReorder(seg *segment) { rc.reorderSeen = true } } + +// setDSACKSeen updates rack control if duplicate SACK is seen by the connection. +func (rc *rackControl) setDSACKSeen() { + rc.dsackSeen = true +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 8e0b7c843..f2b1b68da 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -16,6 +16,7 @@ package tcp import ( "container/heap" + "math" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -48,6 +49,10 @@ type receiver struct { rcvWndScale uint8 + // prevBufused is the snapshot of endpoint rcvBufUsed taken when we + // advertise a receive window. + prevBufUsed int + closed bool // pendingRcvdSegments is bounded by the receive buffer size of the @@ -80,9 +85,9 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // outgoing packets, we should use what we have advertised for acceptability // test. scaledWindowSize := r.rcvWnd >> r.rcvWndScale - if scaledWindowSize > 0xffff { + if scaledWindowSize > math.MaxUint16 { // This is what we actually put in the Window field. - scaledWindowSize = 0xffff + scaledWindowSize = math.MaxUint16 } advertisedWindowSize := scaledWindowSize << r.rcvWndScale return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) @@ -106,6 +111,34 @@ func (r *receiver) currentWindow() (curWnd seqnum.Size) { func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { newWnd := r.ep.selectWindow() curWnd := r.currentWindow() + unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt)) + bufUsed := r.ep.receiveBufferUsed() + + // Grow the right edge of the window only for payloads larger than the + // the segment overhead OR if the application is actively consuming data. + // + // Avoiding growing the right edge otherwise, addresses a situation below: + // An application has been slow in reading data and we have burst of + // incoming segments lengths < segment overhead. Here, our available free + // memory would reduce drastically when compared to the advertised receive + // window. + // + // For example: With incoming 512 bytes segments, segment overhead of + // 552 bytes (at the time of writing this comment), with receive window + // starting from 1MB and with rcvAdvWndScale being 1, buffer would reach 0 + // when the curWnd is still 19436 bytes, because for every incoming segment + // newWnd would reduce by (552+512) >> rcvAdvWndScale (current value 1), + // while curWnd would reduce by 512 bytes. + // Such a situation causes us to keep tail dropping the incoming segments + // and never advertise zero receive window to the peer. + // + // Linux does a similar check for minimal sk_buff size (128): + // https://github.com/torvalds/linux/blob/d5beb3140f91b1c8a3d41b14d729aefa4dcc58bc/net/ipv4/tcp_input.c#L783 + // + // Also, if the application is reading the data, we keep growing the right + // edge, as we are still advertising a window that we think can be serviced. + toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed + // Update rcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we @@ -115,7 +148,7 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP rcvNxt rcvAcc new rcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) { + if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow { // If the new window moves the right edge, then update rcvAcc. r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) } else { @@ -130,11 +163,24 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // receiver's estimated RTT. r.rcvWnd = newWnd r.rcvWUP = r.rcvNxt + r.prevBufUsed = bufUsed scaledWnd := r.rcvWnd >> r.rcvWndScale if scaledWnd == 0 { // Increment a metric if we are advertising an actual zero window. r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() } + + // If we started off with a window larger than what can he held in + // the 16bit window field, we ceil the value to the max value. + // While ceiling, we still do not want to grow the right edge when + // not applicable. + if scaledWnd > math.MaxUint16 { + if toGrow { + scaledWnd = seqnum.Size(math.MaxUint16) + } else { + scaledWnd = seqnum.Size(uint16(scaledWnd)) + } + } return r.rcvNxt, scaledWnd } diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go new file mode 100644 index 000000000..2aa708e97 --- /dev/null +++ b/pkg/tcpip/transport/tcp/reno_recovery.go @@ -0,0 +1,67 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +// renoRecovery stores the variables related to TCP Reno loss recovery +// algorithm. +// +// +stateify savable +type renoRecovery struct { + s *sender +} + +func newRenoRecovery(s *sender) *renoRecovery { + return &renoRecovery{s: s} +} + +func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { + ack := rcvdSeg.ackNumber + snd := rr.s + + // We are in fast recovery mode. Ignore the ack if it's out of range. + if !ack.InRange(snd.sndUna, snd.sndNxt+1) { + return + } + + // Don't count this as a duplicate if it is carrying data or + // updating the window. + if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window { + return + } + + // Inflate the congestion window if we're getting duplicate acks + // for the packet we retransmitted. + if !fastRetransmit && ack == snd.fr.first { + // We received a dup, inflate the congestion window by 1 packet + // if we're not at the max yet. Only inflate the window if + // regular FastRecovery is in use, RFC6675 does not require + // inflating cwnd on duplicate ACKs. + if snd.sndCwnd < snd.fr.maxCwnd { + snd.sndCwnd++ + } + return + } + + // A partial ack was received. Retransmit this packet and remember it + // so that we don't retransmit it again. + // + // We don't inflate the window because we're putting the same packet + // back onto the wire. + // + // N.B. The retransmit timer will be reset by the caller. + snd.fr.first = ack + snd.dupAckCount = 0 + snd.resendSegment() +} diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go new file mode 100644 index 000000000..7e813fa96 --- /dev/null +++ b/pkg/tcpip/transport/tcp/sack_recovery.go @@ -0,0 +1,120 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import "gvisor.dev/gvisor/pkg/tcpip/seqnum" + +// sackRecovery stores the variables related to TCP SACK loss recovery +// algorithm. +// +// +stateify savable +type sackRecovery struct { + s *sender +} + +func newSACKRecovery(s *sender) *sackRecovery { + return &sackRecovery{s: s} +} + +// handleSACKRecovery implements the loss recovery phase as described in RFC6675 +// section 5, step C. +func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { + snd := sr.s + snd.SetPipe() + + if smss := int(snd.ep.scoreboard.SMSS()); limit > smss { + // Cap segment size limit to s.smss as SACK recovery requires + // that all retransmissions or new segments send during recovery + // be of <= SMSS. + limit = smss + } + + nextSegHint := snd.writeList.Front() + for snd.outstanding < snd.sndCwnd { + var nextSeg *segment + var rescueRtx bool + nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint) + if nextSeg == nil { + return dataSent + } + if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + // New data being sent. + + // Step C.3 described below is handled by + // maybeSendSegment which increments sndNxt when + // a segment is transmitted. + // + // Step C.3 "If any of the data octets sent in + // (C.1) are above HighData, HighData must be + // updated to reflect the transmission of + // previously unsent data." + // + // We pass s.smss as the limit as the Step 2) requires that + // new data sent should be of size s.smss or less. + if sent := snd.maybeSendSegment(nextSeg, limit, end); !sent { + return dataSent + } + dataSent = true + snd.outstanding++ + snd.writeNext = nextSeg.Next() + continue + } + + // Now handle the retransmission case where we matched either step 1,3 or 4 + // of the NextSeg algorithm. + // RFC 6675, Step C.4. + // + // "The estimate of the amount of data outstanding in the network + // must be updated by incrementing pipe by the number of octets + // transmitted in (C.1)." + snd.outstanding++ + dataSent = true + snd.sendSegment(nextSeg) + + segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) + if rescueRtx { + // We do the last part of rule (4) of NextSeg here to update + // RescueRxt as until this point we don't know if we are going + // to use the rescue transmission. + snd.fr.rescueRxt = snd.fr.last + } else { + // RFC 6675, Step C.2 + // + // "If any of the data octets sent in (C.1) are below + // HighData, HighRxt MUST be set to the highest sequence + // number of the retransmitted segment unless NextSeg () + // rule (4) was invoked for this retransmission." + snd.fr.highRxt = segEnd - 1 + } + } + return dataSent +} + +func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { + snd := sr.s + if fastRetransmit { + snd.resendSegment() + } + + // We are in fast recovery mode. Ignore the ack if it's out of range. + if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) { + return + } + + // RFC 6675 recovery algorithm step C 1-5. + end := snd.sndUna.Add(snd.sndWnd) + dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end) + snd.postXmit(dataSent) +} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go index 7ef2df377..833a7b470 100644 --- a/pkg/tcpip/transport/tcp/sack_scoreboard.go +++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go @@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool { return found } -// Dump prints the state of the scoreboard structure. +// String returns human-readable state of the scoreboard structure. func (s *SACKScoreboard) String() string { var str strings.Builder str.WriteString("SACKScoreboard: {") diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 1f9c5cf50..5ef73ec74 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -45,9 +46,18 @@ type segment struct { ep *endpoint qFlags queueFlags id stack.TransportEndpointID `state:"manual"` - route stack.Route `state:"manual"` - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - hdr header.TCP + + // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of + // individual members for link/network packet info. + srcAddr tcpip.Address + dstAddr tcpip.Address + netProto tcpip.NetworkProtocolNumber + nicID tcpip.NICID + remoteLinkAddr tcpip.LinkAddress + + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -76,11 +86,16 @@ type segment struct { acked bool } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { + netHdr := pkt.Network() s := &segment{ - refCnt: 1, - id: id, - route: r.Clone(), + refCnt: 1, + id: id, + srcAddr: netHdr.SourceAddress(), + dstAddr: netHdr.DestinationAddress(), + netProto: pkt.NetworkProtocolNumber, + nicID: pkt.NICID, + remoteLinkAddr: pkt.SourceLinkAddress(), } s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) @@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB return s } -func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, - route: r.Clone(), } s.rcvdTime = time.Now() if len(v) != 0 { @@ -110,7 +124,9 @@ func (s *segment) clone() *segment { ackNumber: s.ackNumber, flags: s.flags, window: s.window, - route: s.route.Clone(), + netProto: s.netProto, + nicID: s.nicID, + remoteLinkAddr: s.remoteLinkAddr, viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, @@ -160,7 +176,6 @@ func (s *segment) decRef() { panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) } } - s.route.Release() } } @@ -189,7 +204,7 @@ func (s *segment) payloadSize() int { // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { - return segSize + s.data.Size() + return SegSize + s.data.Size() } // parse populates the sequence & ack numbers, flags, and window fields of the @@ -198,10 +213,10 @@ func (s *segment) segMemSize() int { // // Returns boolean indicating if the parsing was successful. // -// If checksum verification is not offloaded then parse also verifies the +// If checksum verification may not be skipped, parse also verifies the // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. -func (s *segment) parse() bool { +func (s *segment) parse(skipChecksumValidation bool) bool { // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -220,16 +235,14 @@ func (s *segment) parse() bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - // Query the link capabilities to decide if checksum validation is - // required. verifyChecksum := true - if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { + if skipChecksumValidation { s.csumValid = true verifyChecksum = false } if verifyChecksum { s.csum = s.hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go index 0ab7b8f56..392ff0859 100644 --- a/pkg/tcpip/transport/tcp/segment_unsafe.go +++ b/pkg/tcpip/transport/tcp/segment_unsafe.go @@ -19,5 +19,6 @@ import ( ) const ( - segSize = int(unsafe.Sizeof(segment{})) + // SegSize is the minimal size of the segment overhead. + SegSize = int(unsafe.Sizeof(segment{})) ) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 6fa8d63cd..baec762e1 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -18,7 +18,6 @@ import ( "fmt" "math" "sort" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" @@ -92,6 +91,17 @@ type congestionControl interface { PostRecovery() } +// lossRecovery is an interface that must be implemented by any supported +// loss recovery algorithm. +type lossRecovery interface { + // DoRecovery is invoked when loss is detected and segments need + // to be retransmitted. The cumulative or selective ACK is passed along + // with the flag which identifies whether the connection entered fast + // retransmit with this ACK and to retransmit the first unacknowledged + // segment. + DoRecovery(rcvdSeg *segment, fastRetransmit bool) +} + // sender holds the state necessary to send TCP segments. // // +stateify savable @@ -108,6 +118,9 @@ type sender struct { // fr holds state related to fast recovery. fr fastRecovery + // lr is the loss recovery algorithm used by the sender. + lr lossRecovery + // sndCwnd is the congestion window, in packets. sndCwnd int @@ -276,6 +289,8 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.cc = s.initCongestionControl(ep.cc) + s.lr = s.initLossRecovery() + // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. if sndWndScale > 0 { @@ -330,6 +345,14 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon } } +// initLossRecovery initiates the loss recovery algorithm for the sender. +func (s *sender) initLossRecovery() lossRecovery { + if s.ep.sackPermitted { + return newSACKRecovery(s) + } + return newRenoRecovery(s) +} + // updateMaxPayloadSize updates the maximum payload size based on the given // MTU. If this is in response to "packet too big" control packets (indicated // by the count argument), it also reduces the number of outstanding packets and @@ -550,7 +573,7 @@ func (s *sender) retransmitTimerExpired() bool { // We were attempting fast recovery but were not successful. // Leave the state. We don't need to update ssthresh because it // has already been updated when entered fast-recovery. - s.leaveFastRecovery() + s.leaveRecovery() } s.state = RTORecovery @@ -789,7 +812,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } if !nextTooBig && seg.data.Size() < available { // Segment is not full. - if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 { + if s.outstanding > 0 && s.ep.ops.GetDelayOption() { // Nagle's algorithm. From Wikipedia: // Nagle's algorithm works by // combining a number of small @@ -808,7 +831,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // send space and MSS. // TODO(gvisor.dev/issue/2833): Drain the held segments after a // timeout. - if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { + if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() { return false } } @@ -913,79 +936,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se return true } -// handleSACKRecovery implements the loss recovery phase as described in RFC6675 -// section 5, step C. -func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { - s.SetPipe() - - if smss := int(s.ep.scoreboard.SMSS()); limit > smss { - // Cap segment size limit to s.smss as SACK recovery requires - // that all retransmissions or new segments send during recovery - // be of <= SMSS. - limit = smss - } - - nextSegHint := s.writeList.Front() - for s.outstanding < s.sndCwnd { - var nextSeg *segment - var rescueRtx bool - nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint) - if nextSeg == nil { - return dataSent - } - if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) { - // New data being sent. - - // Step C.3 described below is handled by - // maybeSendSegment which increments sndNxt when - // a segment is transmitted. - // - // Step C.3 "If any of the data octets sent in - // (C.1) are above HighData, HighData must be - // updated to reflect the transmission of - // previously unsent data." - // - // We pass s.smss as the limit as the Step 2) requires that - // new data sent should be of size s.smss or less. - if sent := s.maybeSendSegment(nextSeg, limit, end); !sent { - return dataSent - } - dataSent = true - s.outstanding++ - s.writeNext = nextSeg.Next() - continue - } - - // Now handle the retransmission case where we matched either step 1,3 or 4 - // of the NextSeg algorithm. - // RFC 6675, Step C.4. - // - // "The estimate of the amount of data outstanding in the network - // must be updated by incrementing pipe by the number of octets - // transmitted in (C.1)." - s.outstanding++ - dataSent = true - s.sendSegment(nextSeg) - - segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) - if rescueRtx { - // We do the last part of rule (4) of NextSeg here to update - // RescueRxt as until this point we don't know if we are going - // to use the rescue transmission. - s.fr.rescueRxt = s.fr.last - } else { - // RFC 6675, Step C.2 - // - // "If any of the data octets sent in (C.1) are below - // HighData, HighRxt MUST be set to the highest sequence - // number of the retransmitted segment unless NextSeg () - // rule (4) was invoked for this retransmission." - s.fr.highRxt = segEnd - 1 - } - } - return dataSent -} - func (s *sender) sendZeroWindowProbe() { ack, win := s.ep.rcv.getSendParams() s.unackZeroWindowProbes++ @@ -1014,6 +964,30 @@ func (s *sender) disableZeroWindowProbing() { s.resendTimer.disable() } +func (s *sender) postXmit(dataSent bool) { + if dataSent { + // We sent data, so we should stop the keepalive timer to ensure + // that no keepalives are sent while there is pending data. + s.ep.disableKeepaliveTimer() + } + + // If the sender has advertized zero receive window and we have + // data to be sent out, start zero window probing to query the + // the remote for it's receive window size. + if s.writeNext != nil && s.sndWnd == 0 { + s.enableZeroWindowProbing() + } + + // Enable the timer if we have pending data and it's not enabled yet. + if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { + s.resendTimer.enable(s.rto) + } + // If we have no more pending data, start the keepalive timer. + if s.sndUna == s.sndNxt { + s.ep.resetKeepaliveTimer(false) + } +} + // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. func (s *sender) sendData() { @@ -1034,55 +1008,29 @@ func (s *sender) sendData() { } var dataSent bool - - // RFC 6675 recovery algorithm step C 1-5. - if s.fr.active && s.ep.sackPermitted { - dataSent = s.handleSACKRecovery(s.maxPayloadSize, end) - } else { - for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { - cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize - if cwndLimit < limit { - limit = cwndLimit - } - if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - // Move writeNext along so that we don't try and scan data that - // has already been SACKED. - s.writeNext = seg.Next() - continue - } - if sent := s.maybeSendSegment(seg, limit, end); !sent { - break - } - dataSent = true - s.outstanding += s.pCount(seg) + for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + if cwndLimit < limit { + limit = cwndLimit + } + if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Move writeNext along so that we don't try and scan data that + // has already been SACKED. s.writeNext = seg.Next() + continue } + if sent := s.maybeSendSegment(seg, limit, end); !sent { + break + } + dataSent = true + s.outstanding += s.pCount(seg) + s.writeNext = seg.Next() } - if dataSent { - // We sent data, so we should stop the keepalive timer to ensure - // that no keepalives are sent while there is pending data. - s.ep.disableKeepaliveTimer() - } - - // If the sender has advertized zero receive window and we have - // data to be sent out, start zero window probing to query the - // the remote for it's receive window size. - if s.writeNext != nil && s.sndWnd == 0 { - s.enableZeroWindowProbing() - } - - // Enable the timer if we have pending data and it's not enabled yet. - if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { - s.resendTimer.enable(s.rto) - } - // If we have no more pending data, start the keepalive timer. - if s.sndUna == s.sndNxt { - s.ep.resetKeepaliveTimer(false) - } + s.postXmit(dataSent) } -func (s *sender) enterFastRecovery() { +func (s *sender) enterRecovery() { s.fr.active = true // Save state to reflect we're now in fast recovery. // @@ -1104,7 +1052,7 @@ func (s *sender) enterFastRecovery() { s.ep.stack.Stats().TCP.FastRecovery.Increment() } -func (s *sender) leaveFastRecovery() { +func (s *sender) leaveRecovery() { s.fr.active = false s.fr.maxCwnd = 0 s.dupAckCount = 0 @@ -1115,57 +1063,6 @@ func (s *sender) leaveFastRecovery() { s.cc.PostRecovery() } -func (s *sender) handleFastRecovery(seg *segment) (rtx bool) { - ack := seg.ackNumber - // We are in fast recovery mode. Ignore the ack if it's out of - // range. - if !ack.InRange(s.sndUna, s.sndNxt+1) { - return false - } - - // Leave fast recovery if it acknowledges all the data covered by - // this fast recovery session. - if s.fr.last.LessThan(ack) { - s.leaveFastRecovery() - return false - } - - if s.ep.sackPermitted { - // When SACK is enabled we let retransmission be governed by - // the SACK logic. - return false - } - - // Don't count this as a duplicate if it is carrying data or - // updating the window. - if seg.logicalLen() != 0 || s.sndWnd != seg.window { - return false - } - - // Inflate the congestion window if we're getting duplicate acks - // for the packet we retransmitted. - if ack == s.fr.first { - // We received a dup, inflate the congestion window by 1 packet - // if we're not at the max yet. Only inflate the window if - // regular FastRecovery is in use, RFC6675 does not require - // inflating cwnd on duplicate ACKs. - if s.sndCwnd < s.fr.maxCwnd { - s.sndCwnd++ - } - return false - } - - // A partial ack was received. Retransmit this packet and - // remember it so that we don't retransmit it again. We don't - // inflate the window because we're putting the same packet back - // onto the wire. - // - // N.B. The retransmit timer will be reset by the caller. - s.fr.first = ack - s.dupAckCount = 0 - return true -} - // isAssignedSequenceNumber relies on the fact that we only set flags once a // sequencenumber is assigned and that is only done right before we send the // segment. As a result any segment that has a non-zero flag has a valid @@ -1228,14 +1125,11 @@ func (s *sender) SetPipe() { s.outstanding = pipe } -// checkDuplicateAck is called when an ack is received. It manages the state -// related to duplicate acks and determines if a retransmit is needed according -// to the rules in RFC 6582 (NewReno). -func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { +// detectLoss is called when an ack is received and returns whether a loss is +// detected. It manages the state related to duplicate acks and determines if +// a retransmit is needed according to the rules in RFC 6582 (NewReno). +func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { ack := seg.ackNumber - if s.fr.active { - return s.handleFastRecovery(seg) - } // We're not in fast recovery yet. A segment is considered a duplicate // only if it doesn't carry any data and doesn't update the send window, @@ -1266,14 +1160,14 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2 // // We only do the check here, the incrementing of last to the highest - // sequence number transmitted till now is done when enterFastRecovery + // sequence number transmitted till now is done when enterRecovery // is invoked. if !s.fr.last.LessThan(seg.ackNumber) { s.dupAckCount = 0 return false } s.cc.HandleNDupAcks() - s.enterFastRecovery() + s.enterRecovery() s.dupAckCount = 0 return true } @@ -1285,21 +1179,29 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { + // Look for DSACK block. + idx := 0 + n := len(rcvdSeg.parsedOptions.SACKBlocks) + if s.checkDSACK(rcvdSeg) { + s.rc.setDSACKSeen() + idx = 1 + n-- + } + + if n == 0 { + return + } + // Sort the SACK blocks. The first block is the most recent unacked // block. The following blocks can be in arbitrary order. - sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks)) - copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks) + sackBlocks := make([]header.SACKBlock, n) + copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks[idx:]) sort.Slice(sackBlocks, func(i, j int) bool { return sackBlocks[j].Start.LessThan(sackBlocks[i].Start) }) seg := s.writeList.Front() for _, sb := range sackBlocks { - // This check excludes DSACK blocks. - if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) { - continue - } - for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { s.rc.update(seg, rcvdSeg, s.ep.tsOffset) @@ -1311,6 +1213,50 @@ func (s *sender) walkSACK(rcvdSeg *segment) { } } +// checkDSACK checks if a DSACK is reported and updates it in RACK. +func (s *sender) checkDSACK(rcvdSeg *segment) bool { + n := len(rcvdSeg.parsedOptions.SACKBlocks) + if n == 0 { + return false + } + + sb := rcvdSeg.parsedOptions.SACKBlocks[0] + // Check if SACK block is invalid. + if sb.End.LessThan(sb.Start) { + return false + } + + // See: https://tools.ietf.org/html/rfc2883#section-5 DSACK is sent in + // at most one SACK block. DSACK is detected in the below two cases: + // * If the SACK sequence space is less than this cumulative ACK, it is + // an indication that the segment identified by the SACK block has + // been received more than once by the receiver. + // * If the sequence space in the first SACK block is greater than the + // cumulative ACK, then the sender next compares the sequence space + // in the first SACK block with the sequence space in the second SACK + // block, if there is one. This comparison can determine if the first + // SACK block is reporting duplicate data that lies above the + // cumulative ACK. + if sb.Start.LessThan(rcvdSeg.ackNumber) { + return true + } + + if n > 1 { + sb1 := rcvdSeg.parsedOptions.SACKBlocks[1] + if sb1.End.LessThan(sb1.Start) { + return false + } + + // If the first SACK block is fully covered by second SACK + // block, then the first block is a DSACK block. + if sb.End.LessThanEq(sb1.End) && sb1.Start.LessThanEq(sb.Start) { + return true + } + } + + return false +} + // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { @@ -1363,14 +1309,23 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.SetPipe() } - // Count the duplicates and do the fast retransmit if needed. - rtx := s.checkDuplicateAck(rcvdSeg) + ack := rcvdSeg.ackNumber + fastRetransmit := false + // Do not leave fast recovery, if the ACK is out of range. + if s.fr.active { + // Leave fast recovery if it acknowledges all the data covered by + // this fast recovery session. + if ack.InRange(s.sndUna, s.sndNxt+1) && s.fr.last.LessThan(ack) { + s.leaveRecovery() + } + } else { + // Detect loss by counting the duplicates and enter recovery. + fastRetransmit = s.detectLoss(rcvdSeg) + } // Stash away the current window size. s.sndWnd = rcvdSeg.window - ack := rcvdSeg.ackNumber - // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously @@ -1487,19 +1442,24 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.resendTimer.disable() } } + // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. - if rtx { - s.resendSegment() + if s.fr.active { + s.lr.DoRecovery(rcvdSeg, fastRetransmit) + // When SACK is enabled data sending is governed by steps in + // RFC 6675 Section 5 recovery steps A-C. + // See: https://tools.ietf.org/html/rfc6675#section-5. + if s.ep.sackPermitted { + return + } } // Send more data now that some of the pending data has been ack'd, or // that the window opened up, or the congestion window was inflated due // to a duplicate ack during fast recovery. This will also re-enable // the retransmit timer if needed. - if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo { - s.sendData() - } + s.sendData() } // sendSegment sends the specified segment. diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index d3f92b48c..9818ffa0f 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -30,15 +30,17 @@ const ( maxPayload = 10 tsOptionSize = 12 maxTCPOptionSize = 40 + mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload ) // TestRACKUpdate tests the RACK related fields are updated when an ACK is // received on a SACK enabled connection. func TestRACKUpdate(t *testing.T) { - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + c := context.New(t, uint32(mtu)) defer c.Cleanup() var xmitTime time.Time + probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { // Validate that the endpoint Sender.RACKState is what we expect. if state.Sender.RACKState.XmitTime.Before(xmitTime) { @@ -54,6 +56,7 @@ func TestRACKUpdate(t *testing.T) { if state.Sender.RACKState.RTT == 0 { t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") } + close(probeDone) }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) @@ -73,18 +76,20 @@ func TestRACKUpdate(t *testing.T) { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - time.Sleep(200 * time.Millisecond) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone } // TestRACKDetectReorder tests that RACK detects packet reordering. func TestRACKDetectReorder(t *testing.T) { - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + c := context.New(t, uint32(mtu)) defer c.Cleanup() - const ackNum = 2 - var n int - ch := make(chan struct{}) + const ackNumToVerify = 2 + probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { gotSeq := state.Sender.RACKState.FACK wantSeq := state.Sender.SndNxt @@ -95,7 +100,7 @@ func TestRACKDetectReorder(t *testing.T) { } n++ - if n < ackNum { + if n < ackNumToVerify { if state.Sender.RACKState.Reord { t.Fatalf("RACK reorder detected when there is no reordering") } @@ -105,11 +110,11 @@ func TestRACKDetectReorder(t *testing.T) { if state.Sender.RACKState.Reord == false { t.Fatalf("RACK reorder detection failed") } - close(ch) + close(probeDone) }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(ackNum * maxPayload) + data := buffer.NewView(ackNumToVerify * maxPayload) for i := range data { data[i] = byte(i) } @@ -120,7 +125,7 @@ func TestRACKDetectReorder(t *testing.T) { } bytesRead := 0 - for i := 0; i < ackNum; i++ { + for i := 0; i < ackNumToVerify; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload } @@ -133,5 +138,393 @@ func TestRACKDetectReorder(t *testing.T) { // Wait for the probe function to finish processing the ACK before the // test completes. - <-ch + <-probeDone +} + +func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View { + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + data := buffer.NewView(numPackets * maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + for i := 0; i < numPackets; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + + return data +} + +const ( + validDSACKDetected = 1 + failedToDetectDSACK = 2 + invalidDSACKDetected = 3 +) + +func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) { + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK detects DSACK. + n++ + if n < numACK { + if state.Sender.RACKState.DSACKSeen { + probeDone <- invalidDSACKDetected + } + return + } + + if !state.Sender.RACKState.DSACKSeen { + probeDone <- failedToDetectDSACK + return + } + probeDone <- validDSACKDetected + }) +} + +// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.1. +func TestRACKDetectDSACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 2 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 8 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of +// order segments. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.2. +func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 2 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 10 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + // Send DSACK block for #6 along with out of + // order #9 packet is received. + start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload) + end1 := start1.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a +// duplicate of out of order packet. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.3 +func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 4 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 10 + sendAndReceive(t, c, numPackets) + + // ACK [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Send SACK indicating #6 packet is missing and received #7 packet. + offset := seqnum.Size(bytesRead + maxPayload) + start := c.IRS.Add(1 + offset) + end := start.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Send SACK with #6 packet is missing and received [7-8] packets. + end = start.Add(2 * maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Consider #8 packet is duplicated on the network and send DSACK. + dsackStart := c.IRS.Add(1 + offset + maxPayload) + dsackEnd := dsackStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.1. +func TestRACKDetectDSACKSingleDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 4 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 4 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-3] packets and received #4 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // ACK for retransmitted #2 packet. + bytesRead += maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by + // sending DSACK block for the subsegment. + dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous +// duplicate subsegments covered by the cumulative acknowledgement. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.2. +func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 5 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 6 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-5] packets and received #6 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(5*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Received delayed #2 packet. + bytesRead += maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed #4 packet. + start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end1 := start1.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Simulate receiving retransmitted subsegment for #2 packet and delayed #3 + // packet by sending DSACK block for #2 packet. + dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not +// covered by the cumulative acknowledgement. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.3. +func TestRACKDetectDSACKDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 5 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 7 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed #3 packet. + start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + end1 := start1.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Consider #2 packet has been dropped and SACK #4 packet. + start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end2 := start2.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}}) + + // Simulate receiving retransmitted subsegment for #3 packet and delayed #5 + // packet by sending DSACK block for the subsegment. + dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + end1 = end1.Add(seqnum.Size(2 * maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK +// is not the first SACK block. +func TestRACKWithInvalidDSACKBlock(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + const ackNumToVerify = 2 + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK does not detect DSACK when DSACK block is + // not the first SACK block. + n++ + t.Helper() + if state.Sender.RACKState.DSACKSeen { + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } + + if n == ackNumToVerify { + close(probeDone) + } + }) + + numPackets := 10 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + + // Send DSACK block as second block. + start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload) + end1 := start1.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a7149efd0..7581bdc97 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -63,7 +63,7 @@ func TestGiveUpConnect(t *testing.T) { // Register for notification, then start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) - wq.EventRegister(&waitEntry, waiter.EventOut) + wq.EventRegister(&waitEntry, waiter.EventHUp) defer wq.EventUnregister(&waitEntry) if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { @@ -75,9 +75,6 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh - if err := ep.LastError(); err != tcpip.ErrAborted { - t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted) - } // Call Connect again to retreive the handshake failure status // and stats updates. @@ -267,7 +264,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) { } // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. - sent := stats.ICMP.V4PacketsSent + sent := stats.ICMP.V4.PacketsSent if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) } @@ -1448,7 +1445,7 @@ func TestSynSent(t *testing.T) { // Start connection attempt. waitEntry, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&waitEntry, waiter.EventOut) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) defer c.WQ.EventUnregister(&waitEntry) addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} @@ -2532,10 +2529,10 @@ func TestSegmentMerging(t *testing.T) { { "cork", func(ep tcpip.Endpoint) { - ep.SetSockOptBool(tcpip.CorkOption, true) + ep.SocketOptions().SetCorkOption(true) }, func(ep tcpip.Endpoint) { - ep.SetSockOptBool(tcpip.CorkOption, false) + ep.SocketOptions().SetCorkOption(false) }, }, } @@ -2627,7 +2624,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptBool(tcpip.DelayOption, true) + c.EP.SocketOptions().SetDelayOption(true) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -2675,7 +2672,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptBool(tcpip.DelayOption, true) + c.EP.SocketOptions().SetDelayOption(true) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -2708,7 +2705,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOptBool(tcpip.DelayOption, false) + c.EP.SocketOptions().SetDelayOption(false) // Check that data is received. second := c.GetPacket() @@ -2745,8 +2742,8 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, } for _, test := range tests { @@ -3198,6 +3195,11 @@ loop: case tcpip.ErrWouldBlock: select { case <-ch: + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + } + break loop case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for reset to arrive") } @@ -3207,14 +3209,10 @@ loop: t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) } } - // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) - } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) } @@ -4150,7 +4148,7 @@ func TestReadAfterClosedState(t *testing.T) { // Check that peek works. peekBuf := make([]byte, 10) - n, _, err := c.EP.Peek([][]byte{peekBuf}) + n, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { t.Fatalf("Peek failed: %s", err) } @@ -4176,7 +4174,7 @@ func TestReadAfterClosedState(t *testing.T) { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) } - if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { + if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -4193,9 +4191,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4205,9 +4201,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4218,9 +4212,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4233,9 +4225,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4246,9 +4236,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4261,9 +4249,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4656,13 +4642,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { switch network { case "ipv4": case "ipv6": - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err) - } + ep.SocketOptions().SetV6Only(true) case "dual": - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { - t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err) - } + ep.SocketOptions().SetV6Only(false) default: t.Fatalf("unknown network: '%s'", network) } @@ -4998,9 +4980,7 @@ func TestKeepalive(t *testing.T) { if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) } - if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { - t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) - } + c.EP.SocketOptions().SetKeepAlive(true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -5131,6 +5111,7 @@ func TestKeepalive(t *testing.T) { } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendPacket(nil, &context.Headers{ @@ -5175,6 +5156,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki } func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendV6Packet(nil, &context.Headers{ @@ -5238,13 +5220,14 @@ func TestListenBacklogFull(t *testing.T) { // Test acceptance. // Start listening. - listenBacklog := 2 + listenBacklog := 10 if err := c.EP.Listen(listenBacklog); err != nil { t.Fatalf("Listen failed: %s", err) } - for i := 0; i < listenBacklog; i++ { - executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) + lastPortOffset := uint16(0) + for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } time.Sleep(50 * time.Millisecond) @@ -5252,7 +5235,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 2, + SrcPort: context.TestPort + uint16(lastPortOffset), DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(789), @@ -5293,7 +5276,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { @@ -5714,6 +5697,50 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } } +func TestSYNRetransmit(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send the same SYN packet multiple times. We should still get a valid SYN-ACK + // reply. + irs := seqnum.Value(789) + for i := 0; i < 5; i++ { + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + } + + // Receive the SYN-ACK reply. + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs) + 1), + } + checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...)) +} + func TestSynRcvdBadSeqNumber(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -6071,10 +6098,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Introduce a 25ms latency by delaying the first byte. latency := 25 * time.Millisecond time.Sleep(latency) - rawEP.SendPacketWithTS([]byte{1}, tsVal) + // Send an initial payload with atleast segment overhead size. The receive + // window would not grow for smaller segments. + rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal) pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() + time.Sleep(25 * time.Millisecond) // Allocate a large enough payload for the test. @@ -6347,10 +6377,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.T if err != nil { t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) } - gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) - } + gotDelayOption := ep.SocketOptions().GetDelayOption() if gotDelayOption != wantDelayOption { t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) } @@ -6722,6 +6749,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + // drain any older notifications from the notification channel before attempting + // 2nd connection. + select { + case <-ch: + default: + } + // Send a SYN request w/ sequence number higher than // the highest sequence number sent. iss = seqnum.Value(792) @@ -7196,9 +7230,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) } - if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { - t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) - } + c.EP.SocketOptions().SetKeepAlive(true) // Set userTimeout to be the duration to be 1 keepalive // probes. Which means that after the first probe is sent diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 4d7847142..ee55f030c 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -112,6 +112,18 @@ type Headers struct { TCPOpts []byte } +// Options contains options for creating a new test context. +type Options struct { + // EnableV4 indicates whether IPv4 should be enabled. + EnableV4 bool + + // EnableV6 indicates whether IPv4 should be enabled. + EnableV6 bool + + // MTU indicates the maximum transmission unit on the link layer. + MTU uint32 +} + // Context provides an initialized Network stack and a link layer endpoint // for use in TCP tests. type Context struct { @@ -154,10 +166,30 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + return NewWithOpts(t, Options{ + EnableV4: true, + EnableV6: true, + MTU: mtu, }) +} + +// NewWithOpts allocates and initializes a test context containing a new +// stack and a link-layer endpoint with specific options. +func NewWithOpts(t *testing.T, opts Options) *Context { + if opts.MTU == 0 { + panic("MTU must be greater than 0") + } + + stackOpts := stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + } + if opts.EnableV4 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) + } + if opts.EnableV6 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol) + } + s := stack.New(stackOpts) const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB @@ -182,50 +214,55 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - ep := channel.New(1000, mtu, "") + ep := channel.New(1000, opts.MTU, "") wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - opts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, "")) if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, mtu, "")) + wep2 = sniffer.New(channel.New(1000, opts.MTU, "")) } opts2 := stack.NICOptions{Name: "nic2"} if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - v4ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: StackAddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) - } - - v6ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: StackV6AddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) - } + var routeTable []tcpip.Route - s.SetRouteTable([]tcpip.Route{ - { + if opts.EnableV4 { + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, NIC: 1, - }, - { + }) + } + + if opts.EnableV6 { + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, NIC: 1, - }, - }) + }) + } + + s.SetRouteTable(routeTable) return &Context{ t: t, @@ -358,7 +395,6 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(header.ICMPv4ProtocolNumber), @@ -373,6 +409,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, const icmpv4VariableHeaderOffset = 4 copy(icmp[icmpv4VariableHeaderOffset:], p1) copy(icmp[header.ICMPv4PayloadOffset:], p2) + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) // Inject packet. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -397,7 +436,6 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // Initialize the IP header. ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(tcp.ProtocolNumber), @@ -554,9 +592,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } + c.EP.SocketOptions().SetV6Only(v6only) } // GetV6Packet reads a single packet from the link layer endpoint of the context @@ -599,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - NextHeader: uint8(tcp.ProtocolNumber), - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(header.TCPMinimumSize + len(payload)), + TransportProtocol: tcp.ProtocolNumber, + HopLimit: 65, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 7981d469b..38a335840 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) { // cleanup frees all resources associated with the timer. func (t *timer) cleanup() { + if t.timer == nil { + // No cleanup needed. + return + } t.timer.Stop() *t = timer{} } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c78549424..153e8c950 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -56,6 +56,8 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d57ed5d79..763d1d654 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -16,8 +16,8 @@ package udp import ( "fmt" + "sync/atomic" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -30,10 +30,11 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry - senderAddress tcpip.FullAddress - packetInfo tcpip.IPPacketInfo - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + senderAddress tcpip.FullAddress + destinationAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 // tos stores either the receiveTOS or receiveTClass value. tos uint8 } @@ -77,6 +78,7 @@ func (s EndpointState) String() string { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. @@ -94,22 +96,20 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + // state must be read/set using the EndpointState()/setEndpointState() + // methods. state EndpointState - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` dstPort uint16 - v6only bool ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID - multicastLoop bool portFlags ports.Flags bindToDevice tcpip.NICID - broadcast bool - noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -123,17 +123,6 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - // receiveTOS determines if the incoming IPv4 TOS header field is passed - // as ancillary data to ControlMessages on Read. - receiveTOS bool - - // receiveTClass determines if the incoming IPv6 TClass header field is - // passed as ancillary data to ControlMessages on Read. - receiveTClass bool - - // receiveIPPacketInfo determines if the packet info is returned by Read. - receiveIPPacketInfo bool - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -155,8 +144,8 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption + // ops is used to get socket level options. + ops tcpip.SocketOptions } // +stateify savable @@ -186,13 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } + e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -208,6 +198,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue return e } +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState() returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + // UniqueID implements stack.TransportEndpoint.UniqueID. func (e *endpoint) UniqueID() uint64 { return e.uniqueID @@ -233,7 +237,7 @@ func (e *endpoint) Close() { e.mu.Lock() e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { + switch e.EndpointState() { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) @@ -256,10 +260,13 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } // Update the state. - e.state = StateClosed + e.setEndpointState(StateClosed) e.mu.Unlock() @@ -301,24 +308,23 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess HasTimestamp: true, Timestamp: p.timestamp, } - e.mu.RLock() - receiveTOS := e.receiveTOS - receiveTClass := e.receiveTClass - receiveIPPacketInfo := e.receiveIPPacketInfo - e.mu.RUnlock() - if receiveTOS { + if e.ops.GetReceiveTOS() { cm.HasTOS = true cm.TOS = p.tos } - if receiveTClass { + if e.ops.GetReceiveTClass() { cm.HasTClass = true // Although TClass is an 8-bit value it's read in the CMsg as a uint32. cm.TClass = uint32(p.tos) } - if receiveIPPacketInfo { + if e.ops.GetReceivePacketInfo() { cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } + if e.ops.GetReceiveOriginalDstAddress() { + cm.HasOriginalDstAddress = true + cm.OriginalDstAddress = p.destinationAddress + } return p.data.ToView(), cm, nil } @@ -328,7 +334,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // // Returns true for retry if preparation should be retried. func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { - switch e.state { + switch e.EndpointState() { case StateInitial: case StateConnected: return false, nil @@ -350,7 +356,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return true, nil } @@ -365,9 +371,9 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress - if isBroadcastOrMulticast(localAddr) { + if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" } @@ -382,9 +388,9 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) if err != nil { - return stack.Route{}, 0, err + return nil, 0, err } return r, nicID, nil } @@ -427,7 +433,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c to := opts.To e.mu.RLock() - defer e.mu.RUnlock() + lockReleased := false + defer func() { + if lockReleased { + return + } + e.mu.RUnlock() + }() // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { @@ -446,36 +458,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } } - var route *stack.Route - var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) - var dstPort uint16 - if to == nil { - route = &e.route - dstPort = e.dstPort - resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) { - // Promote lock to exclusive if using a shared route, given that it may - // need to change in Route.Resolve() call below. - e.mu.RUnlock() - e.mu.Lock() - - // Recheck state after lock was re-acquired. - if e.state != StateConnected { - err = tcpip.ErrInvalidEndpointState - } - if err == nil && route.IsResolutionRequired() { - ch, err = route.Resolve(waker) - } - - e.mu.Unlock() - e.mu.RLock() - - // Recheck state after lock was re-acquired. - if e.state != StateConnected { - err = tcpip.ErrInvalidEndpointState - } - return - } - } else { + route := e.route + dstPort := e.dstPort + if to != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC @@ -487,6 +472,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } + if to.Port == 0 { + // Port 0 is an invalid port to send to. + return 0, nil, tcpip.ErrInvalidEndpointState + } + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err @@ -498,17 +488,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } defer r.Release() - route = &r + route = r dstPort = dst.Port - resolve = route.Resolve } - if !e.broadcast && route.IsOutboundBroadcast() { + if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { return 0, nil, tcpip.ErrBroadcastDisabled } if route.IsResolutionRequired() { - if ch, err := resolve(nil); err != nil { + if ch, err := route.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { return 0, ch, tcpip.ErrNoLinkAddress } @@ -534,83 +523,46 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { + localPort := e.ID.LocalPort + sendTOS := e.sendTOS + owner := e.owner + noChecksum := e.SocketOptions().GetNoChecksum() + lockReleased = true + e.mu.RUnlock() + + // Do not hold lock when sending as loopback is synchronous and if the UDP + // datagram ends up generating an ICMP response then it can result in a + // deadlock where the ICMP response handling ends up acquiring this endpoint's + // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a + // deadlock if another caller is trying to acquire e.mu in exclusive mode w/ + // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the + // lock can be eventually acquired. + // + // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read + // locking is prohibited. + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v - e.mu.Unlock() - - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = v - e.mu.Unlock() - - case tcpip.NoChecksumOption: - e.mu.Lock() - e.noChecksum = v - e.mu.Unlock() - - case tcpip.ReceiveTOSOption: - e.mu.Lock() - e.receiveTOS = v - e.mu.Unlock() - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrNotSupported - } - - e.mu.Lock() - e.receiveTClass = v - e.mu.Unlock() - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.portFlags.MostRecent = v - e.mu.Unlock() - - case tcpip.ReusePortOption: - e.mu.Lock() - e.portFlags.LoadBalanced = v - e.mu.Unlock() - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v - } +// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +func (e *endpoint) OnReuseAddressSet(v bool) { + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() +} - return nil +// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +func (e *endpoint) OnReusePortSet(v bool) { + e.mu.Lock() + e.portFlags.LoadBalanced = v + e.mu.Unlock() } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. @@ -813,93 +765,10 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() } return nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.BroadcastOption: - e.mu.RLock() - v := e.broadcast - e.mu.RUnlock() - return v, nil - - case tcpip.KeepaliveEnabledOption: - return false, nil - - case tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - return v, nil - - case tcpip.NoChecksumOption: - e.mu.RLock() - v := e.noChecksum - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTOSOption: - e.mu.RLock() - v := e.receiveTOS - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrNotSupported - } - - e.mu.RLock() - v := e.receiveTClass - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil - - case tcpip.ReuseAddressOption: - e.mu.RLock() - v := e.portFlags.MostRecent - e.mu.RUnlock() - - return v, nil - - case tcpip.ReusePortOption: - e.mu.RLock() - v := e.portFlags.LoadBalanced - e.mu.RUnlock() - - return v, nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.mu.RLock() - v := e.v6only - e.mu.RUnlock() - - return v, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -974,11 +843,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() - case *tcpip.LingerOption: - e.mu.RLock() - *o = e.linger - e.mu.RUnlock() - default: return tcpip.ErrUnknownProtocolOption } @@ -1009,7 +873,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // On IPv4, UDP checksum is optional, and a zero value indicates the // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + if r.RequiresTXTransportChecksum() && (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { @@ -1038,7 +902,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -1050,7 +914,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.state != StateConnected { + if e.EndpointState() != StateConnected { return nil } var ( @@ -1073,7 +937,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { if err != nil { return err } - e.state = StateBound + e.setEndpointState(StateBound) boundPortFlags = e.boundPortFlags } else { if e.ID.LocalPort != 0 { @@ -1081,14 +945,14 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - e.state = StateInitial + e.setEndpointState(StateInitial) } e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() - e.route = stack.Route{} + e.route = nil e.dstPort = 0 return nil @@ -1106,7 +970,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicID := addr.NIC var localPort uint16 - switch e.state { + switch e.EndpointState() { case StateInitial: case StateBound, StateConnected: localPort = e.ID.LocalPort @@ -1141,7 +1005,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { RemoteAddress: r.RemoteAddress, } - if e.state == StateInitial { + if e.EndpointState() == StateInitial { id.LocalAddress = r.LocalAddress } @@ -1149,7 +1013,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv4ProtocolNumber, header.IPv6ProtocolNumber, @@ -1175,7 +1039,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.RegisterNICID = nicID e.effectiveNetProtos = netProtos - e.state = StateConnected + e.setEndpointState(StateConnected) e.rcvMu.Lock() e.rcvReady = true @@ -1197,7 +1061,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. - if e.state != StateBound && e.state != StateConnected { + if state := e.EndpointState(); state != StateBound && state != StateConnected { return tcpip.ErrNotConnected } @@ -1248,7 +1112,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -1261,7 +1125,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // wildcard (empty) address, and this is an IPv6 endpoint with v6only // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv6ProtocolNumber, header.IPv4ProtocolNumber, @@ -1269,7 +1133,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicID := addr.NIC - if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { // A local unicast address was specified, verify that it's valid. nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicID == 0 { @@ -1292,7 +1156,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { e.effectiveNetProtos = netProtos // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) e.rcvMu.Lock() e.rcvReady = true @@ -1324,7 +1188,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() addr := e.ID.LocalAddress - if e.state == StateConnected { + if e.EndpointState() == StateConnected { addr = e.route.LocalAddress } @@ -1340,7 +1204,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != StateConnected { + if e.EndpointState() != StateConnected { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -1366,6 +1230,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.rcvMu.Unlock() } + e.lastErrorMu.Lock() + hasError := e.lastError != nil + e.lastErrorMu.Unlock() + if hasError { + result |= waiter.EventErr + } return result } @@ -1373,10 +1243,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // On IPv4, UDP checksum is optional, and a zero value means the transmitter // omitted the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). -func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool { - if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && - (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) +func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { + if !pkt.RXTransportChecksumValidated && + (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { + netHdr := pkt.Network() + xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) } @@ -1387,8 +1258,7 @@ func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) boo // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - // Get the header then trim it from the view. +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. @@ -1397,7 +1267,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - if !verifyChecksum(r, hdr, pkt) { + // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap + // packets at "Parse" instead of when handling a packet. + pkt.Data.CapLength(int(hdr.PayloadLength())) + + if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() @@ -1428,9 +1302,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &udpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), + }, + destinationAddress: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: header.UDP(hdr).DestinationPort(), }, } packet.data = pkt.Data @@ -1438,7 +1317,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvBufSize += pkt.Data.Size() // Save any useful information from the network header to the packet. - switch r.NetProto { + switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: @@ -1448,9 +1327,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast // address. packetInfo.LocalAddr should hold a unicast address that can be // used to respond to the incoming packet. - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.LocalAddress - packet.packetInfo.NIC = r.NICID() + localAddr := pkt.Network().DestinationAddress() + packet.packetInfo.LocalAddr = localAddr + packet.packetInfo.DestinationAddr = localAddr + packet.packetInfo.NIC = pkt.NICID packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() @@ -1464,23 +1344,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { - e.mu.RLock() - defer e.mu.RUnlock() - - if e.state == StateConnected { + if e.EndpointState() == StateConnected { e.lastErrorMu.Lock() - defer e.lastErrorMu.Unlock() - e.lastError = tcpip.ErrConnectionRefused + e.lastErrorMu.Unlock() + + e.waiterQueue.Notify(waiter.EventErr) + return } } } // State implements tcpip.Endpoint.State. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. @@ -1500,10 +1377,16 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements tcpip.Endpoint.Wait. func (*endpoint) Wait() {} -func isBroadcastOrMulticast(a tcpip.Address) bool { - return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 858c99a45..13b72dc88 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) { } } - if e.state != StateBound && e.state != StateConnected { + state := e.EndpointState() + if state != StateBound && state != StateConnected { return } @@ -113,12 +114,12 @@ func (e *endpoint) Resume(s *stack.Stack) { } var err *tcpip.Error - if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) + if state == StateConnected { + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { panic(err) } - } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 3ae6cc221..14e4648cd 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,10 +43,9 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, - route: r, id: id, pkt: pkt, }) @@ -59,7 +58,6 @@ func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, p // it via CreateEndpoint. type ForwarderRequest struct { stack *stack.Stack - route *stack.Route id stack.TransportEndpointID pkt *stack.PacketBuffer } @@ -72,17 +70,25 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + netHdr := r.pkt.Network() + route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(r.pkt.SourceLinkAddress()) + + ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() + route.Release() return nil, err } ep.ID = r.id - ep.route = r.route.Clone() + ep.route = route ep.dstPort = r.id.RemotePort - ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto} - ep.RegisterNICID = r.route.NICID() + ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} + ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags ep.state = StateConnected @@ -91,7 +97,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.pkt) + ep.HandlePacket(r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index da5b1deb2..91420edd3 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -78,15 +78,15 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets that are targeted at this // protocol but don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + p.stack.Stats().UDP.MalformedPacketsReceived.Increment() return stack.UnknownDestinationPacketMalformed } - if !verifyChecksum(r, hdr, pkt) { - r.Stack().Stats().UDP.ChecksumErrors.Increment() + if !verifyChecksum(hdr, pkt) { + p.stack.Stats().UDP.ChecksumErrors.Increment() return stack.UnknownDestinationPacketMalformed } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b4604ba35..08980c298 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -32,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,6 +56,7 @@ const ( stackPort = 1234 testAddr = "\x0a\x00\x00\x02" testPort = 4096 + invalidPort = 8192 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast @@ -295,7 +298,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: true, }) } @@ -360,13 +364,9 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetV6Only(true) } else if flow.isBroadcast() { - if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - c.t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetBroadcast(true) } } @@ -453,12 +453,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. @@ -490,7 +490,6 @@ func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View // Initialize the IP header. ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ - IHL: header.IPv4MinimumSize, TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, @@ -975,7 +974,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { // provided. func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() - return testWriteInternal(c, flow, true, checkers...) + return testWriteAndVerifyInternal(c, flow, true, checkers...) } // testWriteWithoutDestination sends a packet of the given test flow from the @@ -984,10 +983,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker // checker functions provided. func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() - return testWriteInternal(c, flow, false, checkers...) + return testWriteAndVerifyInternal(c, flow, false, checkers...) } -func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { +func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View { c.t.Helper() // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() @@ -1009,6 +1008,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } c.checkEndpointWriteStats(1, epstats, err) + return payload +} + +func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + payload := testWriteNoVerify(c, flow, setDest) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -1153,6 +1158,39 @@ func TestV4WriteOnConnected(t *testing.T) { testWriteWithoutDestination(c, unicastV4) } +func TestWriteOnConnectedInvalidPort(t *testing.T) { + protocols := map[string]tcpip.NetworkProtocolNumber{ + "ipv4": ipv4.ProtocolNumber, + "ipv6": ipv6.ProtocolNumber, + } + for name, pn := range protocols { + t.Run(name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(pn) + if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil { + c.t.Fatalf("Connect failed: %s", err) + } + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, + } + payload := buffer.View(newPayload()) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + if err != nil { + c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) + } + if got, want := n, int64(len(payload)); got != want { + c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) + } + + if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused { + c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + } + }) + } +} + // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket // that is bound to a V4 multicast address. func TestWriteOnBoundToV4Multicast(t *testing.T) { @@ -1375,9 +1413,7 @@ func TestReadIPPacketInfo(t *testing.T) { } } - if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { - t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) - } + c.ep.SocketOptions().SetReceivePacketInfo(true) testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ NIC: 1, @@ -1392,6 +1428,93 @@ func TestReadIPPacketInfo(t *testing.T) { } } +func TestReadRecvOriginalDstAddr(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedOriginalDstAddr tcpip.FullAddress + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%#v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) + } + } + + c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) + + testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1415,16 +1538,12 @@ func TestNoChecksum(t *testing.T) { c.createEndpointForFlow(flow) // Disable the checksum generation. - if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { - t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetNoChecksum(true) // This option is effective on IPv4 only. testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) // Enable the checksum generation. - if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { - t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetNoChecksum(false) testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) }) } @@ -1452,6 +1571,14 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1586,13 +1713,15 @@ func TestSetTClass(t *testing.T) { } func TestReceiveTosTClass(t *testing.T) { + const RcvTOSOpt = "ReceiveTosOption" + const RcvTClassOpt = "ReceiveTClassOption" + testCases := []struct { - name string - getReceiveOption tcpip.SockOptBool - tests []testFlow + name string + tests []testFlow }{ - {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, - {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, + {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, } for _, testCase := range testCases { for _, flow := range testCase.tests { @@ -1601,29 +1730,32 @@ func TestReceiveTosTClass(t *testing.T) { defer c.cleanup() c.createEndpointForFlow(flow) - option := testCase.getReceiveOption name := testCase.name - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) + var optionGetter func() bool + var optionSetter func(bool) + switch name { + case RcvTOSOpt: + optionGetter = c.ep.SocketOptions().GetReceiveTOS + optionSetter = c.ep.SocketOptions().SetReceiveTOS + case RcvTClassOpt: + optionGetter = c.ep.SocketOptions().GetReceiveTClass + optionSetter = c.ep.SocketOptions().SetReceiveTClass + default: + t.Fatalf("unkown test variant: %s", name) } + + // Verify that setting and reading the option works. + v := optionGetter() // Test for expected default value. if v != false { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) } want := true - if err := c.ep.SetSockOptBool(option, want); err != nil { - c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) - } - - got, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) - } + optionSetter(want) + got := optionGetter() if got != want { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) } @@ -1633,10 +1765,10 @@ func TestReceiveTosTClass(t *testing.T) { if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %s", err) } - switch option { - case tcpip.ReceiveTClassOption: + switch name { + case RcvTClassOpt: testRead(c, flow, checker.ReceiveTClass(testTOS)) - case tcpip.ReceiveTOSOption: + case RcvTOSOpt: testRead(c, flow, checker.ReceiveTOS(testTOS)) default: t.Fatalf("unknown test variant: %s", name) @@ -1783,28 +1915,31 @@ func TestV4UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize - wantLen := len(payload) + wantPayloadLen := len(payload) if tc.largePayload { // To work out the data size we need to simulate what the sender would // have done. The wanted size is the total available minus the sum of // the headers in the UDP AND ICMP packets, given that we know the test // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength - + wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength } // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. // Add back the two headers within the payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) - + payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength)) origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + wantDgramLen := wantPayloadLen + header.UDPMinimumSize + + if got, want := len(origDgram), wantDgramLen; got != want { + t.Fatalf("got len(origDgram) = %d, want = %d", got, want) } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) + // Correct UDP length to access payload. + origDgram.SetLength(uint16(wantDgramLen)) + + if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) { + t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want) } }) } @@ -1879,20 +2014,23 @@ func TestV6UnknownDestination(t *testing.T) { icmpPkt := header.ICMPv6(hdr.Payload()) payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) + wantPayloadLen := len(payload) if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize } + wantDgramLen := wantPayloadLen + header.UDPMinimumSize // In case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + payloadIPHeader.SetPayloadLength(uint16(wantDgramLen)) origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want { + t.Fatalf("got len(origDgram) = %d, want = %d", got, want) } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) + // Correct UDP length to access payload. + origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize)) + if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" { + t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff) } }) } @@ -1951,12 +2089,12 @@ func TestShortHeader(t *testing.T) { // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(udpSize), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. @@ -2391,17 +2529,13 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) } - if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err) - } + ep.SocketOptions().SetBroadcast(true) if n, _, err := ep.Write(data, opts); err != nil { t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) } - if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil { - t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err) - } + ep.SocketOptions().SetBroadcast(false) if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) @@ -2409,3 +2543,67 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { }) } } + +func TestReceiveShortLength(t *testing.T) { + flows := []testFlow{unicastV4, unicastV6} + for _, flow := range flows { + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Bind to wildcard. + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err) + } + + payload := newPayload() + extraBytes := []byte{1, 2, 3, 4} + h := flow.header4Tuple(incoming) + var buf buffer.View + var proto tcpip.NetworkProtocolNumber + + // Build packets with extra bytes not accounted for in the UDP length + // field. + var udp header.UDP + if flow.isV4() { + buf = c.buildV4Packet(payload, &h) + buf = append(buf, extraBytes...) + ip := header.IPv4(buf) + ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes))) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + proto = ipv4.ProtocolNumber + udp = ip.Payload() + } else { + buf = c.buildV6Packet(payload, &h) + buf = append(buf, extraBytes...) + ip := header.IPv6(buf) + ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes))) + proto = ipv6.ProtocolNumber + udp = ip.Payload() + } + + if diff := cmp.Diff(payload, udp.Payload()); diff != "" { + t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff) + } + + c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + // Try to receive the data. + v, _, err := c.ep.Read(nil) + if err != nil { + t.Fatalf("c.ep.Read(nil): %s", err) + } + + // Check the payload is read back without extra bytes. + if diff := cmp.Diff(buffer.View(payload), v); diff != "" { + t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go index 70945f234..e41769017 100644 --- a/pkg/test/criutil/criutil.go +++ b/pkg/test/criutil/criutil.go @@ -54,14 +54,20 @@ func ResolvePath(executable string) string { } } + // Favor /usr/local/bin, if it exists. + localBin := fmt.Sprintf("/usr/local/bin/%s", executable) + if _, err := os.Stat(localBin); err == nil { + return localBin + } + // Try to find via the path. - guess, err := exec.LookPath(executable) + guess, _ := exec.LookPath(executable) if err == nil { return guess } - // Return a default path. - return fmt.Sprintf("/usr/local/bin/%s", executable) + // Return a bare path; this generates a suitable error. + return executable } // NewCrictl returns a Crictl configured with a timeout and an endpoint over diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 64d17f661..2bf0a22ff 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -17,6 +17,7 @@ package dockerutil import ( "bytes" "context" + "errors" "fmt" "io/ioutil" "net" @@ -351,6 +352,9 @@ func (c *Container) SandboxPid(ctx context.Context) (int, error) { return resp.ContainerJSONBase.State.Pid, nil } +// ErrNoIP indicates that no IP address is available. +var ErrNoIP = errors.New("no IP available") + // FindIP returns the IP address of the container. func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) { resp, err := c.client.ContainerInspect(ctx, c.id) @@ -365,7 +369,7 @@ func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) { ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress) } if ip == nil { - return net.IP{}, fmt.Errorf("invalid IP: %q", ip) + return net.IP{}, ErrNoIP } return ip, nil } diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go index 4c739c9e9..bf968acec 100644 --- a/pkg/test/dockerutil/exec.go +++ b/pkg/test/dockerutil/exec.go @@ -77,11 +77,6 @@ func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Proc return Process{}, fmt.Errorf("exec attach failed with err: %v", err) } - if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil { - hijack.Close() - return Process{}, fmt.Errorf("exec start failed with err: %v", err) - } - return Process{ container: c, execid: resp.ID, diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index 49ab87c58..fdd416b5e 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -36,7 +36,6 @@ import ( "path/filepath" "strconv" "strings" - "sync/atomic" "syscall" "testing" "time" @@ -49,7 +48,10 @@ import ( ) var ( - checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support") + checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support") + partition = flag.Int("partition", 1, "partition number, this is 1-indexed") + totalPartitions = flag.Int("total_partitions", 1, "total number of partitions") + isRunningWithHostNet = flag.Bool("hostnet", false, "whether test is running with hostnet") ) // IsCheckpointSupported returns the relevant command line flag. @@ -57,6 +59,11 @@ func IsCheckpointSupported() bool { return *checkpoint } +// IsRunningWithHostNet returns the relevant command line flag. +func IsRunningWithHostNet() bool { + return *isRunningWithHostNet +} + // ImageByName mangles the image name used locally. This depends on the image // build infrastructure in images/ and tools/vm. func ImageByName(name string) string { @@ -249,14 +256,25 @@ func writeSpec(dir string, spec *specs.Spec) error { // idRandomSrc is a pseudo random generator used to in RandomID. var idRandomSrc = rand.New(rand.NewSource(time.Now().UnixNano())) +// idRandomSrcMtx is the mutex protecting idRandomSrc.Read from being used +// concurrently in differnt goroutines. +var idRandomSrcMtx sync.Mutex + // RandomID returns 20 random bytes following the given prefix. func RandomID(prefix string) string { // Read 20 random bytes. b := make([]byte, 20) + // Rand.Read is not safe for concurrent use. Packetimpact tests can be run in + // parallel now, so we have to protect the Read with a mutex. Otherwise we'll + // run into name conflicts. + // https://golang.org/pkg/math/rand/#Rand.Read + idRandomSrcMtx.Lock() // "[Read] always returns len(p) and a nil error." --godoc if _, err := idRandomSrc.Read(b); err != nil { + idRandomSrcMtx.Unlock() panic("rand.Read failed: " + err.Error()) } + idRandomSrcMtx.Unlock() if prefix != "" { prefix = prefix + "-" } @@ -417,33 +435,35 @@ func StartReaper() func() { // WaitUntilRead reads from the given reader until the wanted string is found // or until timeout. -func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time.Duration) error { +func WaitUntilRead(r io.Reader, want string, timeout time.Duration) error { sc := bufio.NewScanner(r) - if split != nil { - sc.Split(split) - } // done must be accessed atomically. A value greater than 0 indicates // that the read loop can exit. - var done uint32 - doneCh := make(chan struct{}) + doneCh := make(chan bool) + defer close(doneCh) go func() { for sc.Scan() { t := sc.Text() if strings.Contains(t, want) { - atomic.StoreUint32(&done, 1) - close(doneCh) - break + doneCh <- true + return } - if atomic.LoadUint32(&done) > 0 { - break + select { + case <-doneCh: + return + default: } } + doneCh <- false }() + select { case <-time.After(timeout): - atomic.StoreUint32(&done, 1) return fmt.Errorf("timeout waiting to read %q", want) - case <-doneCh: + case res := <-doneCh: + if !res { + return fmt.Errorf("reader closed while waiting to read %q", want) + } return nil } } @@ -509,7 +529,8 @@ func TouchShardStatusFile() error { } // TestIndicesForShard returns indices for this test shard based on the -// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. +// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars, as well as +// the passed partition flags. // // If either of the env vars are not present, then the function will return all // tests. If there are more shards than there are tests, then the returned list @@ -534,6 +555,11 @@ func TestIndicesForShard(numTests int) ([]int, error) { } } + // Combine with the partitions. + partitionSize := shardTotal + shardTotal = (*totalPartitions) * shardTotal + shardIndex = partitionSize*(*partition-1) + shardIndex + // Calculate! var indices []int numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal))) diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go index 5c4b9e8e9..a38ffc19d 100644 --- a/pkg/unet/unet_test.go +++ b/pkg/unet/unet_test.go @@ -53,40 +53,40 @@ func randomFilename() (string, error) { func TestConnectFailure(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } if _, err := Connect(name, false); err == nil { - t.Fatalf("connect was successful, expected err") + t.Fatalf("Connect was successful, expected err") } } func TestBindFailure(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } defer ss.Close() if _, err = BindAndListen(name, false); err == nil { - t.Fatalf("second bind succeeded, expected non-nil err") + t.Fatalf("Second bind succeeded, expected non-nil err") } } func TestMultipleAccept(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } defer ss.Close() @@ -99,7 +99,8 @@ func TestMultipleAccept(t *testing.T) { defer wg.Done() s, err := Connect(name, false) if err != nil { - t.Fatalf("connect failed, got err %v expected nil", err) + t.Errorf("Connect failed, got err %v expected nil", err) + return } s.Close() }() @@ -109,7 +110,7 @@ func TestMultipleAccept(t *testing.T) { for i := 0; i < backlog; i++ { s, err := ss.Accept() if err != nil { - t.Errorf("accept failed, got err %v expected nil", err) + t.Errorf("Accept failed, got err %v expected nil", err) continue } s.Close() @@ -119,35 +120,35 @@ func TestMultipleAccept(t *testing.T) { func TestServerClose(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } // Make sure the first close succeeds. if err := ss.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) + t.Fatalf("First close failed, got err %v expected nil", err) } // The second one should fail. if err := ss.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") + t.Fatalf("Second close succeeded, expected non-nil err") } } func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } // Bind a server. ss, err := BindAndListen(name, packet) if err != nil { - t.Fatalf("error binding, got %v expected nil", err) + t.Fatalf("Error binding, got %v expected nil", err) } defer ss.Close() @@ -165,7 +166,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { // Connect the client. client, err := Connect(name, packet) if err != nil { - t.Fatalf("error connecting, got %v expected nil", err) + t.Fatalf("Error connecting, got %v expected nil", err) } // Grab the server handle. @@ -173,7 +174,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { case server := <-acceptSocket: return server, client case err := <-acceptErr: - t.Fatalf("accept error: %v", err) + t.Fatalf("Accept error: %v", err) } panic("unreachable") } @@ -186,17 +187,17 @@ func TestSendRecv(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the server. b := [][]byte{{'b'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -211,17 +212,17 @@ func TestSymmetric(t *testing.T) { // Write on the server. w := server.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the client. b := [][]byte{{'b'}} r := client.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -233,13 +234,13 @@ func TestPacket(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Write on the client again. w = client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the server. @@ -249,19 +250,19 @@ func TestPacket(t *testing.T) { b := [][]byte{{'b', 'b'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Do it again. r = server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -271,12 +272,12 @@ func TestClose(t *testing.T) { // Make sure the first close succeeds. if err := client.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) + t.Fatalf("First close failed, got err %v expected nil", err) } // The second one should fail. if err := client.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") + t.Fatalf("Second close succeeded, expected non-nil err") } } @@ -294,17 +295,17 @@ func TestNonBlockingSend(t *testing.T) { // We're good. That's what we wanted. blockCount++ } else { - t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1000 err=nil", n, err) } } } if blockCount == 1000 { // Shouldn't have _always_ blocked. - t.Fatalf("socket always blocked!") + t.Fatalf("Socket always blocked!") } else if blockCount == 0 { // Should have started blocking eventually. - t.Fatalf("socket never blocked!") + t.Fatalf("Socket never blocked!") } } @@ -319,25 +320,25 @@ func TestNonBlockingRecv(t *testing.T) { // Expected to block immediately. _, err := r.ReadVec(b) if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) + t.Fatalf("Read didn't block, got err %v expected blocking err", err) } // Put some data in the pipe. w := server.Writer(false) if n, err := w.WriteVec(b); n != 1 || err != nil { - t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("Write failed with n=%d err=%v, expected n=1 err=nil", n, err) } // Expect it not to block. if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("Read failed with n=%d err=%v, expected n=1 err=nil", n, err) } // Expect it to return a block error again. r = client.Reader(false) _, err = r.ReadVec(b) if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) + t.Fatalf("Read didn't block, got err %v expected blocking err", err) } } @@ -349,17 +350,17 @@ func TestRecvVectors(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) } // Read on the server. b := [][]byte{{'c'}, {'c'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) } if b[0][0] != 'a' || b[1][0] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) + t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) } } @@ -371,17 +372,17 @@ func TestSendVectors(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) } // Read on the server. b := [][]byte{{'c', 'c'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) } if b[0][0] != 'a' || b[0][1] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) + t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) } } @@ -394,23 +395,23 @@ func TestSendFDsNotEnabled(t *testing.T) { w := server.Writer(true) w.PackFDs(0, 1, 2) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the client, without enabling FDs. b := [][]byte{{'b'}} r := client.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Make sure the FDs are not received. fds, err := r.ExtractFDs() if len(fds) != 0 || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) + t.Fatalf("Got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) } } @@ -418,7 +419,7 @@ func sendFDs(t *testing.T, s *Socket, fds []int) { w := s.Writer(true) w.PackFDs(fds...) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For write, got n=%d err=%v, expected n=1 err=nil", n, err) } } @@ -428,7 +429,7 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { // Count the number of FDs. preEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } // Read on the client. @@ -438,31 +439,31 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { r.EnableFDs(enableSize) } if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Count the new number of FDs. postEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } if len(preEntries)+expected != len(postEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) + t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) } // Make sure the FDs are there. fds, err := r.ExtractFDs() if len(fds) != expected || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) + t.Fatalf("Got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) } // Make sure they are different from the originals. for i := 0; i < len(fds); i++ { if fds[i] == origFDs[i] { - t.Errorf("got original fd for index %d, expected different", i) + t.Errorf("Got original fd for index %d, expected different", i) } } @@ -480,10 +481,10 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { // Make sure the count is back to normal. finalEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } if len(finalEntries) != len(preEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) + t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) } } @@ -567,7 +568,7 @@ func TestGetPeerCred(t *testing.T) { } if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) + t.Errorf("GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) } } @@ -594,53 +595,53 @@ func TestGetPeerCredFailure(t *testing.T) { want := "bad file descriptor" if _, err := s.GetPeerCred(); err == nil || err.Error() != want { - t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want) + t.Errorf("s.GetPeerCred() = %v, want = %s", err, want) } } func TestAcceptClosed(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) + t.Fatalf("Close failed, got err %v expected nil", err) } if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } } func TestCloseAfterAcceptStart(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() time.Sleep(50 * time.Millisecond) if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) + t.Errorf("Close failed, got err %v expected nil", err) } - wg.Done() }() if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } wg.Wait() @@ -649,28 +650,28 @@ func TestCloseAfterAcceptStart(t *testing.T) { func TestReleaseAfterAcceptStart(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() time.Sleep(50 * time.Millisecond) fd, err := ss.Release() if err != nil { - t.Fatalf("Release failed, got err %v expected nil", err) + t.Errorf("Release failed, got err %v expected nil", err) } syscall.Close(fd) - wg.Done() }() if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } wg.Wait() @@ -688,7 +689,7 @@ func TestControlMessage(t *testing.T) { cm.PackFDs(want...) got, err := cm.ExtractFDs() if err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) + t.Errorf("cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) } } } @@ -705,11 +706,13 @@ func benchmarkSendRecv(b *testing.B, packet bool) { for i := 0; i < b.N; i++ { n, err := server.Read(buf) if n != 1 || err != nil { - b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err) + b.Errorf("server.Read: got (%d, %v), wanted (1, nil)", n, err) + return } n, err = server.Write(buf) if n != 1 || err != nil { - b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err) + b.Errorf("server.Write: got (%d, %v), wanted (1, nil)", n, err) + return } } }() diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 9b1e7a085..79db8895b 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -167,7 +167,7 @@ func (rw *IOReadWriter) Read(dst []byte) (int, error) { return n, err } -// Writer implements io.Writer.Write. +// Write implements io.Writer.Write. func (rw *IOReadWriter) Write(src []byte) (int, error) { n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts) end, ok := rw.Addr.AddLength(uint64(n)) diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 67a950444..83d4f893a 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -119,7 +119,10 @@ type EntryCallback interface { // The callback is supposed to perform minimal work, and cannot call // any method on the queue itself because it will be locked while the // callback is running. - Callback(e *Entry) + // + // The mask indicates the events that occurred and that the entry is + // interested in. + Callback(e *Entry, mask EventMask) } // Entry represents a waiter that can be add to the a wait queue. It can @@ -140,7 +143,7 @@ type channelCallback struct { } // Callback implements EntryCallback.Callback. -func (c *channelCallback) Callback(*Entry) { +func (c *channelCallback) Callback(*Entry, EventMask) { select { case c.ch <- struct{}{}: default: @@ -168,7 +171,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { // // +stateify savable type Queue struct { - list waiterList `state:"zerovalue"` + list waiterList mu sync.RWMutex `state:"nosave"` } @@ -193,8 +196,8 @@ func (q *Queue) EventUnregister(e *Entry) { func (q *Queue) Notify(mask EventMask) { q.mu.RLock() for e := q.list.Front(); e != nil; e = e.Next() { - if mask&e.mask != 0 { - e.Callback.Callback(e) + if m := mask & e.mask; m != 0 { + e.Callback.Callback(e, m) } } q.mu.RUnlock() diff --git a/pkg/waiter/waiter_test.go b/pkg/waiter/waiter_test.go index c1b94a4f3..6928f28b4 100644 --- a/pkg/waiter/waiter_test.go +++ b/pkg/waiter/waiter_test.go @@ -20,12 +20,12 @@ import ( ) type callbackStub struct { - f func(e *Entry) + f func(e *Entry, m EventMask) } // Callback implements EntryCallback.Callback. -func (c *callbackStub) Callback(e *Entry) { - c.f(e) +func (c *callbackStub) Callback(e *Entry, m EventMask) { + c.f(e, m) } func TestEmptyQueue(t *testing.T) { @@ -36,7 +36,7 @@ func TestEmptyQueue(t *testing.T) { // Register then unregister a waiter, then notify the queue. cnt := 0 - e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} + e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}} q.EventRegister(&e, EventIn) q.EventUnregister(&e) q.Notify(EventIn) @@ -49,7 +49,7 @@ func TestMask(t *testing.T) { // Register a waiter. var q Queue var cnt int - e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} + e := Entry{Callback: &callbackStub{func(*Entry, EventMask) { cnt++ }}} q.EventRegister(&e, EventIn|EventErr) // Notify with an overlapping mask. @@ -101,11 +101,14 @@ func TestConcurrentRegistration(t *testing.T) { for i := 0; i < concurrency; i++ { go func() { var e Entry - e.Callback = &callbackStub{func(entry *Entry) { + e.Callback = &callbackStub{func(entry *Entry, mask EventMask) { cnt++ if entry != &e { t.Errorf("entry = %p, want %p", entry, &e) } + if mask != EventIn { + t.Errorf("mask = %#x want %#x", mask, EventIn) + } }} // Wait for notification, then register. @@ -158,11 +161,14 @@ func TestConcurrentNotification(t *testing.T) { // Register waiters. for i := 0; i < waiterCount; i++ { var e Entry - e.Callback = &callbackStub{func(entry *Entry) { + e.Callback = &callbackStub{func(entry *Entry, mask EventMask) { atomic.AddInt32(&cnt, 1) if entry != &e { t.Errorf("entry = %p, want %p", entry, &e) } + if mask != EventIn { + t.Errorf("mask = %#x want %#x", mask, EventIn) + } }} q.EventRegister(&e, EventIn|EventErr) |