diff options
Diffstat (limited to 'pkg')
190 files changed, 8104 insertions, 3565 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a0654df2f..8fa61d6f7 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -21,6 +21,7 @@ go_library( "epoll_amd64.go", "epoll_arm64.go", "errors.go", + "errqueue.go", "eventfd.go", "exec.go", "fadvise.go", diff --git a/pkg/abi/linux/errqueue.go b/pkg/abi/linux/errqueue.go new file mode 100644 index 000000000..3905d4222 --- /dev/null +++ b/pkg/abi/linux/errqueue.go @@ -0,0 +1,93 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + +// Socket error origin codes as defined in include/uapi/linux/errqueue.h. +const ( + SO_EE_ORIGIN_NONE = 0 + SO_EE_ORIGIN_LOCAL = 1 + SO_EE_ORIGIN_ICMP = 2 + SO_EE_ORIGIN_ICMP6 = 3 +) + +// SockExtendedErr represents struct sock_extended_err in Linux defined in +// include/uapi/linux/errqueue.h. +// +// +marshal +type SockExtendedErr struct { + Errno uint32 + Origin uint8 + Type uint8 + Code uint8 + Pad uint8 + Info uint32 + Data uint32 +} + +// SockErrCMsg represents the IP*_RECVERR control message. +type SockErrCMsg interface { + marshal.Marshallable + + CMsgLevel() uint32 + CMsgType() uint32 +} + +// SockErrCMsgIPv4 is the IP_RECVERR control message used in +// recvmsg(MSG_ERRQUEUE) by ipv4 sockets. This is equilavent to `struct errhdr` +// defined in net/ipv4/ip_sockglue.c:ip_recv_error(). +// +// +marshal +type SockErrCMsgIPv4 struct { + SockExtendedErr + Offender SockAddrInet +} + +var _ SockErrCMsg = (*SockErrCMsgIPv4)(nil) + +// CMsgLevel implements SockErrCMsg.CMsgLevel. +func (*SockErrCMsgIPv4) CMsgLevel() uint32 { + return SOL_IP +} + +// CMsgType implements SockErrCMsg.CMsgType. +func (*SockErrCMsgIPv4) CMsgType() uint32 { + return IP_RECVERR +} + +// SockErrCMsgIPv6 is the IPV6_RECVERR control message used in +// recvmsg(MSG_ERRQUEUE) by ipv6 sockets. This is equilavent to `struct errhdr` +// defined in net/ipv6/datagram.c:ipv6_recv_error(). +// +// +marshal +type SockErrCMsgIPv6 struct { + SockExtendedErr + Offender SockAddrInet6 +} + +var _ SockErrCMsg = (*SockErrCMsgIPv6)(nil) + +// CMsgLevel implements SockErrCMsg.CMsgLevel. +func (*SockErrCMsgIPv6) CMsgLevel() uint32 { + return SOL_IPV6 +} + +// CMsgType implements SockErrCMsg.CMsgType. +func (*SockErrCMsgIPv6) CMsgType() uint32 { + return IPV6_RECVERR +} diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go index d91c97a64..1070b457c 100644 --- a/pkg/abi/linux/fuse.go +++ b/pkg/abi/linux/fuse.go @@ -19,16 +19,22 @@ import ( "gvisor.dev/gvisor/pkg/marshal/primitive" ) +// FUSEOpcode is a FUSE operation code. +// // +marshal type FUSEOpcode uint32 +// FUSEOpID is a FUSE operation ID. +// // +marshal type FUSEOpID uint64 // FUSE_ROOT_ID is the id of root inode. const FUSE_ROOT_ID = 1 -// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. +// Opcodes for FUSE operations. +// +// Analogous to the opcodes in include/linux/fuse.h. const ( FUSE_LOOKUP FUSEOpcode = 1 FUSE_FORGET = 2 /* no reply */ diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index 0adff8dff..2424884c1 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -43,10 +43,10 @@ const ( SEMVMX = 32767 SEMAEM = SEMVMX - // followings are unused in kernel SEMUME = SEMOPM SEMMNU = SEMMNS SEMMAP = SEMMNS + SEMUSZ = 20 ) const SEM_UNDO = 0x1000 diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index d156d41e4..556892dc3 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -111,12 +111,12 @@ type SockType int // Socket types, from linux/net.h. const ( SOCK_STREAM SockType = 1 - SOCK_DGRAM = 2 - SOCK_RAW = 3 - SOCK_RDM = 4 - SOCK_SEQPACKET = 5 - SOCK_DCCP = 6 - SOCK_PACKET = 10 + SOCK_DGRAM SockType = 2 + SOCK_RAW SockType = 3 + SOCK_RDM SockType = 4 + SOCK_SEQPACKET SockType = 5 + SOCK_DCCP SockType = 6 + SOCK_PACKET SockType = 10 ) // SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are @@ -448,6 +448,8 @@ type ControlMessageCredentials struct { // A ControlMessageIPPacketInfo is IP_PKTINFO socket control message. // // ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. +// +// +stateify savable type ControlMessageIPPacketInfo struct { NIC int32 LocalAddr InetAddr diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go index a4f4b2c5e..fdfe31417 100644 --- a/pkg/coverage/coverage.go +++ b/pkg/coverage/coverage.go @@ -27,6 +27,7 @@ import ( "io" "sort" "sync/atomic" + "testing" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -34,12 +35,6 @@ import ( "github.com/bazelbuild/rules_go/go/tools/coverdata" ) -// KcovAvailable returns whether the kcov coverage interface is available. It is -// available as long as coverage is enabled for some files. -func KcovAvailable() bool { - return len(coverdata.Cover.Blocks) > 0 -} - // coverageMu must be held while accessing coverdata.Cover. This prevents // concurrent reads/writes from multiple threads collecting coverage data. var coverageMu sync.RWMutex @@ -47,6 +42,22 @@ var coverageMu sync.RWMutex // once ensures that globalData is only initialized once. var once sync.Once +// blockBitLength is the number of bits used to represent coverage block index +// in a synthetic PC (the rest are used to represent the file index). Even +// though a PC has 64 bits, we only use the lower 32 bits because some users +// (e.g., syzkaller) may truncate that address to a 32-bit value. +// +// As of this writing, there are ~1200 files that can be instrumented and at +// most ~1200 blocks per file, so 16 bits is more than enough to represent every +// file and every block. +const blockBitLength = 16 + +// KcovAvailable returns whether the kcov coverage interface is available. It is +// available as long as coverage is enabled for some files. +func KcovAvailable() bool { + return len(coverdata.Cover.Blocks) > 0 +} + var globalData struct { // files is the set of covered files sorted by filename. It is calculated at // startup. @@ -104,14 +115,14 @@ var coveragePool = sync.Pool{ // coverage tools, we reset the global coverage data every time this function is // run. func ConsumeCoverageData(w io.Writer) int { - once.Do(initCoverageData) + InitCoverageData() coverageMu.Lock() defer coverageMu.Unlock() total := 0 var pcBuffer [8]byte - for fileIndex, file := range globalData.files { + for fileNum, file := range globalData.files { counters := coverdata.Cover.Counters[file] for index := 0; index < len(counters); index++ { if atomic.LoadUint32(&counters[index]) == 0 { @@ -119,7 +130,7 @@ func ConsumeCoverageData(w io.Writer) int { } // Non-zero coverage data found; consume it and report as a PC. atomic.StoreUint32(&counters[index], 0) - pc := globalData.syntheticPCs[fileIndex][index] + pc := globalData.syntheticPCs[fileNum][index] usermem.ByteOrder.PutUint64(pcBuffer[:], pc) n, err := w.Write(pcBuffer[:]) if err != nil { @@ -142,31 +153,84 @@ func ConsumeCoverageData(w io.Writer) int { return total } -// initCoverageData initializes globalData. It should only be called once, -// before any kcov data is written. -func initCoverageData() { - // First, order all files. Then calculate synthetic PCs for every block - // (using the well-defined ordering for files as well). - for file := range coverdata.Cover.Blocks { - globalData.files = append(globalData.files, file) +// InitCoverageData initializes globalData. It should be called before any kcov +// data is written. +func InitCoverageData() { + once.Do(func() { + // First, order all files. Then calculate synthetic PCs for every block + // (using the well-defined ordering for files as well). + for file := range coverdata.Cover.Blocks { + globalData.files = append(globalData.files, file) + } + sort.Strings(globalData.files) + + for fileNum, file := range globalData.files { + blocks := coverdata.Cover.Blocks[file] + pcs := make([]uint64, 0, len(blocks)) + for blockNum := range blocks { + pcs = append(pcs, calculateSyntheticPC(fileNum, blockNum)) + } + globalData.syntheticPCs = append(globalData.syntheticPCs, pcs) + } + }) +} + +// Symbolize prints information about the block corresponding to pc. +func Symbolize(out io.Writer, pc uint64) error { + fileNum, blockNum := syntheticPCToIndexes(pc) + file, err := fileFromIndex(fileNum) + if err != nil { + return err + } + block, err := blockFromIndex(file, blockNum) + if err != nil { + return err } - sort.Strings(globalData.files) - - // nextSyntheticPC is the first PC that we generate for a block. - // - // This uses a standard-looking kernel range for simplicity. - // - // FIXME(b/160639712): This is only necessary because syzkaller requires - // addresses in the kernel range. If we can remove this constraint, then we - // should be able to use the actual addresses. - var nextSyntheticPC uint64 = 0xffffffff80000000 - for _, file := range globalData.files { - blocks := coverdata.Cover.Blocks[file] - thisFile := make([]uint64, 0, len(blocks)) - for range blocks { - thisFile = append(thisFile, nextSyntheticPC) - nextSyntheticPC++ // Advance. + writeBlock(out, pc, file, block) + return nil +} + +// WriteAllBlocks prints all information about all blocks along with their +// corresponding synthetic PCs. +func WriteAllBlocks(out io.Writer) { + for fileNum, file := range globalData.files { + for blockNum, block := range coverdata.Cover.Blocks[file] { + writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block) } - globalData.syntheticPCs = append(globalData.syntheticPCs, thisFile) } } + +func calculateSyntheticPC(fileNum int, blockNum int) uint64 { + return (uint64(fileNum) << blockBitLength) + uint64(blockNum) +} + +func syntheticPCToIndexes(pc uint64) (fileNum int, blockNum int) { + return int(pc >> blockBitLength), int(pc & ((1 << blockBitLength) - 1)) +} + +// fileFromIndex returns the name of the file in the sorted list of instrumented files. +func fileFromIndex(i int) (string, error) { + total := len(globalData.files) + if i < 0 || i >= total { + return "", fmt.Errorf("file index out of range: [%d] with length %d", i, total) + } + return globalData.files[i], nil +} + +// blockFromIndex returns the i-th block in the given file. +func blockFromIndex(file string, i int) (testing.CoverBlock, error) { + blocks, ok := coverdata.Cover.Blocks[file] + if !ok { + return testing.CoverBlock{}, fmt.Errorf("instrumented file %s does not exist", file) + } + total := len(blocks) + if i < 0 || i >= total { + return testing.CoverBlock{}, fmt.Errorf("block index out of range: [%d] with length %d", i, total) + } + return blocks[i], nil +} + +func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) { + io.WriteString(out, fmt.Sprintf("%#x\n", pc)) + io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1)) +} diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index f7f9dbf86..69eeb7528 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -36,3 +36,14 @@ package cpuid // On arm64, features are numbered according to the ELF HWCAP definition. // arch/arm64/include/uapi/asm/hwcap.h type Feature int + +// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a +// subset of the host feature set. +type ErrIncompatible struct { + message string +} + +// Error implements error. +func (e ErrIncompatible) Error() string { + return e.message +} diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go index 17a89c00d..392711e8f 100644 --- a/pkg/cpuid/cpuid_x86.go +++ b/pkg/cpuid/cpuid_x86.go @@ -681,17 +681,6 @@ func (fs *FeatureSet) Intel() bool { return fs.VendorID == intelVendorID } -// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a -// subset of the host feature set. -type ErrIncompatible struct { - message string -} - -// Error implements error. -func (e ErrIncompatible) Error() string { - return e.message -} - // CheckHostCompatible returns nil if fs is a subset of the host feature set. func (fs *FeatureSet) CheckHostCompatible() error { hfs := HostFeatureSet() diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index aa8e4e1f3..cc31d0175 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -11,7 +11,8 @@ go_library( "futex_linux.go", "io.go", "packet_window_allocator.go", - "packet_window_mmap.go", + "packet_window_mmap_amd64.go", + "packet_window_mmap_arm64.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap_amd64.go index 869183b11..869183b11 100644 --- a/pkg/flipcall/packet_window_mmap.go +++ b/pkg/flipcall/packet_window_mmap_amd64.go diff --git a/pkg/flipcall/packet_window_mmap_arm64.go b/pkg/flipcall/packet_window_mmap_arm64.go new file mode 100644 index 000000000..b9c9c44f6 --- /dev/null +++ b/pkg/flipcall/packet_window_mmap_arm64.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package flipcall + +import ( + "syscall" +) + +// Return a memory mapping of the pwd in memory that can be shared outside the sandbox. +func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) { + m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) + return m, err +} diff --git a/pkg/log/json.go b/pkg/log/json.go index bdf9d691e..8c52dcc87 100644 --- a/pkg/log/json.go +++ b/pkg/log/json.go @@ -27,8 +27,8 @@ type jsonLog struct { } // MarshalJSON implements json.Marshaler.MarashalJSON. -func (lv Level) MarshalJSON() ([]byte, error) { - switch lv { +func (l Level) MarshalJSON() ([]byte, error) { + switch l { case Warning: return []byte(`"warning"`), nil case Info: @@ -36,20 +36,20 @@ func (lv Level) MarshalJSON() ([]byte, error) { case Debug: return []byte(`"debug"`), nil default: - return nil, fmt.Errorf("unknown level %v", lv) + return nil, fmt.Errorf("unknown level %v", l) } } // UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON. It can unmarshal // from both string names and integers. -func (lv *Level) UnmarshalJSON(b []byte) error { +func (l *Level) UnmarshalJSON(b []byte) error { switch s := string(b); s { case "0", `"warning"`: - *lv = Warning + *l = Warning case "1", `"info"`: - *lv = Info + *l = Info case "2", `"debug"`: - *lv = Debug + *l = Debug default: return fmt.Errorf("unknown level %q", s) } diff --git a/pkg/log/log.go b/pkg/log/log.go index 37e0605ad..2e3408357 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -356,7 +356,7 @@ func CopyStandardLogTo(l Level) error { case Warning: f = Warningf default: - return fmt.Errorf("Unknown log level %v", l) + return fmt.Errorf("unknown log level %v", l) } stdlog.SetOutput(linewriter.NewWriter(func(p []byte) { diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 6acee90ef..aea7dde38 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -350,9 +350,13 @@ type VerifyParams struct { // For verifyMetadata, params.data is not needed. It only accesses params.tree // for the raw root hash. func verifyMetadata(params *VerifyParams, layout *Layout) error { - root := make([]byte, layout.digestSize) - if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil { - return fmt.Errorf("failed to read root hash: %w", err) + var root []byte + // Only read the root hash if we expect that the Merkle tree file is non-empty. + if params.Size != 0 { + root = make([]byte, layout.digestSize) + if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil { + return fmt.Errorf("failed to read root hash: %w", err) + } } descriptor := VerityDescriptor{ Name: params.Name, diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e605b14c..2e3d427ae 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -678,16 +678,15 @@ func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string // case. defer checkDeleted(h, dst) } else { + // If the type is different than the destination, then + // we expect the rename to fail. We expect that this + // is returned. + // + // If the file being renamed to itself, this is + // technically allowed and a no-op, but all the + // triggers will fire. if !selfRename { - // If the type is different than the - // destination, then we expect the rename to - // fail. We expect ensure that this is - // returned. expectedErr = syscall.EINVAL - } else { - // This is the file being renamed to itself. - // This is technically allowed and a no-op, but - // all the triggers will fire. } dst.Close() } diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index e7406b374..a29f06ddb 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -197,33 +197,33 @@ func BenchmarkSendRecv(b *testing.B) { for i := 0; i < b.N; i++ { tag, m, err := recv(server, maximumLength, msgRegistry.get) if err != nil { - b.Fatalf("recv got err %v expected nil", err) + b.Errorf("recv got err %v expected nil", err) } if tag != Tag(1) { - b.Fatalf("got tag %v expected 1", tag) + b.Errorf("got tag %v expected 1", tag) } if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %T expected *Rflush", m) + b.Errorf("got message %T expected *Rflush", m) } if err := send(server, Tag(2), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) + b.Errorf("send got err %v expected nil", err) } } }() b.ResetTimer() for i := 0; i < b.N; i++ { if err := send(client, Tag(1), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) + b.Errorf("send got err %v expected nil", err) } tag, m, err := recv(client, maximumLength, msgRegistry.get) if err != nil { - b.Fatalf("recv got err %v expected nil", err) + b.Errorf("recv got err %v expected nil", err) } if tag != Tag(2) { - b.Fatalf("got tag %v expected 2", tag) + b.Errorf("got tag %v expected 2", tag) } if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %v expected *Rflush", m) + b.Errorf("got message %v expected *Rflush", m) } } } diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index a1b2e0cfe..54e825b28 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package pool provides a trivial integer pool. package pool import ( diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD index bfa1daa10..0377c0876 100644 --- a/pkg/refsvfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -9,7 +9,7 @@ go_template( "refs_template.go", ], opt_consts = [ - "logTrace", + "enableLogging", ], types = [ "T", diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index f64b6c6ae..3fbc91aa5 100644 --- a/pkg/refsvfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -74,11 +74,6 @@ func (r *Refs) LogRefs() bool { return enableLogging } -// EnableLeakCheck enables reference leak checking on r. -func (r *Refs) EnableLeakCheck() { - refsvfs2.Register(r) -} - // ReadRefs returns the current number of references. The returned count is // inherently racy and is unsafe to use without external synchronization. func (r *Refs) ReadRefs() int64 { @@ -136,7 +131,7 @@ func (r *Refs) TryIncRef() bool { func (r *Refs) DecRef(destroy func()) { v := atomic.AddInt64(&r.refCount, -1) if enableLogging { - refsvfs2.LogDecRef(r, v+1) + refsvfs2.LogDecRef(r, v) } switch { case v < 0: @@ -153,6 +148,6 @@ func (r *Refs) DecRef(destroy func()) { func (r *Refs) afterLoad() { if r.ReadRefs() > 0 { - r.EnableLeakCheck() + refsvfs2.Register(r) } } diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go index e7fd30743..7857f5853 100644 --- a/pkg/safemem/block_unsafe.go +++ b/pkg/safemem/block_unsafe.go @@ -68,29 +68,29 @@ func blockFromSlice(slice []byte, needSafecopy bool) Block { } } -// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is +// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+length), which is // safe to access without safecopy. // -// Preconditions: ptr+len does not overflow. -func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block { - return blockFromPointer(ptr, len, false) +// Preconditions: ptr+length does not overflow. +func BlockFromSafePointer(ptr unsafe.Pointer, length int) Block { + return blockFromPointer(ptr, length, false) } // BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which // is not safe to access without safecopy. // // Preconditions: ptr+len does not overflow. -func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block { - return blockFromPointer(ptr, len, true) +func BlockFromUnsafePointer(ptr unsafe.Pointer, length int) Block { + return blockFromPointer(ptr, length, true) } -func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block { - if uptr := uintptr(ptr); uptr+uintptr(len) < uptr { - panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len)) +func blockFromPointer(ptr unsafe.Pointer, length int, needSafecopy bool) Block { + if uptr := uintptr(ptr); uptr+uintptr(length) < uptr { + panic(fmt.Sprintf("ptr %#x + len %#x overflows", uptr, length)) } return Block{ start: ptr, - length: len, + length: length, needSafecopy: needSafecopy, } } diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index 752e2dc32..ec17ebc4d 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -79,7 +79,7 @@ func Install(rules SyscallRules) error { // Perform the actual installation. if errno := SetFilter(instrs); errno != 0 { - return fmt.Errorf("Failed to set filter: %v", errno) + return fmt.Errorf("failed to set filter: %v", errno) } log.Infof("Seccomp filters installed.") diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go index 7cd895cc7..652c010da 100644 --- a/pkg/segment/test/set_functions.go +++ b/pkg/segment/test/set_functions.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package segment is a test package. package segment type setFunctions struct{} diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index d75d665ae..dd2effdf9 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint { func (a SyscallArgument) ModeT() uint { return uint(uint16(a.Value)) } + +// ErrFloatingPoint indicates a failed restore due to unusable floating point +// state. +type ErrFloatingPoint struct { + // supported is the supported floating point state. + supported uint64 + + // saved is the saved floating point state. + saved uint64 +} + +// Error returns a sensible description of the restore error. +func (e ErrFloatingPoint) Error() string { + return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) +} diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index 19ce99d25..840e53d33 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -17,27 +17,10 @@ package arch import ( - "fmt" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/usermem" ) -// ErrFloatingPoint indicates a failed restore due to unusable floating point -// state. -type ErrFloatingPoint struct { - // supported is the supported floating point state. - supported uint64 - - // saved is the saved floating point state. - saved uint64 -} - -// Error returns a sensible description of the restore error. -func (e ErrFloatingPoint) Error() string { - return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) -} - // XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 // and SSE state, so this is the equivalent XSTATE_BV value. const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index 5138f3bf5..35d2e07c3 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() { } } -// Pid returns the si_pid field. -func (s *SignalInfo) Pid() int32 { +// PID returns the si_pid field. +func (s *SignalInfo) PID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[0:4])) } -// SetPid mutates the si_pid field. -func (s *SignalInfo) SetPid(val int32) { +// SetPID mutates the si_pid field. +func (s *SignalInfo) SetPID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) } -// Uid returns the si_uid field. -func (s *SignalInfo) Uid() int32 { +// UID returns the si_uid field. +func (s *SignalInfo) UID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[4:8])) } -// SetUid mutates the si_uid field. -func (s *SignalInfo) SetUid(val int32) { +// SetUID mutates the si_uid field. +func (s *SignalInfo) SetUID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) } diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 2bf3c45e1..b78e29416 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -15,10 +15,10 @@ package control import ( - "errors" "runtime" "runtime/pprof" "runtime/trace" + "time" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -26,184 +26,253 @@ import ( "gvisor.dev/gvisor/pkg/urpc" ) -var errNoOutput = errors.New("no output writer provided") +// Profile includes profile-related RPC stubs. It provides a way to +// control the built-in runtime profiling facilities. +// +// The profile object must be instantied via NewProfile. +type Profile struct { + // kernel is the kernel under profile. It's immutable. + kernel *kernel.Kernel -// ProfileOpts contains options for the StartCPUProfile/Goroutine RPC call. -type ProfileOpts struct { - // File is the filesystem path for the profile. - File string `json:"path"` + // cpuMu protects CPU profiling. + cpuMu sync.Mutex - // FilePayload is the destination for the profiling output. - urpc.FilePayload + // blockMu protects block profiling. + blockMu sync.Mutex + + // mutexMu protects mutex profiling. + mutexMu sync.Mutex + + // traceMu protects trace profiling. + traceMu sync.Mutex + + // done is closed when profiling is done. + done chan struct{} } -// Profile includes profile-related RPC stubs. It provides a way to -// control the built-in pprof facility in sentry via sentryctl. -// -// The following options to sentryctl are added: +// NewProfile returns a new Profile object, and a stop callback. // -// - collect CPU profile on-demand. -// sentryctl -pid <pid> pprof-cpu-start -// sentryctl -pid <pid> pprof-cpu-stop -// -// - dump out the stack trace of current go routines. -// sentryctl -pid <pid> pprof-goroutine -type Profile struct { - // Kernel is the kernel under profile. It's immutable. - Kernel *kernel.Kernel +// The stop callback should be used at most once. +func NewProfile(k *kernel.Kernel) (*Profile, func()) { + p := &Profile{ + kernel: k, + done: make(chan struct{}), + } + return p, func() { + close(p.done) + } +} - // mu protects the fields below. - mu sync.Mutex +// CPUProfileOpts contains options specifically for CPU profiles. +type CPUProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload - // cpuFile is the current CPU profile output file. - cpuFile *fd.FD + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` - // traceFile is the current execution trace output file. - traceFile *fd.FD + // Hz is the rate, which may be zero. + Hz int `json:"hz"` } -// StartCPUProfile is an RPC stub which starts recording the CPU profile in a -// file. -func (p *Profile) StartCPUProfile(o *ProfileOpts, _ *struct{}) error { +// CPU is an RPC stub which collects a CPU profile. +func (p *Profile) CPU(o *CPUProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } output, err := fd.NewFromFile(o.FilePayload.Files[0]) if err != nil { return err } + defer output.Close() - p.mu.Lock() - defer p.mu.Unlock() + p.cpuMu.Lock() + defer p.cpuMu.Unlock() // Returns an error if profiling is already started. + if o.Hz != 0 { + runtime.SetCPUProfileRate(o.Hz) + } if err := pprof.StartCPUProfile(output); err != nil { - output.Close() return err } + defer pprof.StopCPUProfile() - p.cpuFile = output - return nil -} - -// StopCPUProfile is an RPC stub which stops the CPU profiling and flush out the -// profile data. It takes no argument. -func (p *Profile) StopCPUProfile(_, _ *struct{}) error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.cpuFile == nil { - return errors.New("CPU profiling not started") + // Collect the profile. + select { + case <-time.After(o.Duration): + case <-p.done: } - pprof.StopCPUProfile() - p.cpuFile.Close() - p.cpuFile = nil return nil } -// HeapProfile generates a heap profile for the sentry. -func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error { +// HeapProfileOpts contains options specifically for heap profiles. +type HeapProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload +} + +// Heap generates a heap profile. +func (p *Profile) Heap(o *HeapProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() + runtime.GC() // Get up-to-date statistics. - if err := pprof.WriteHeapProfile(output); err != nil { - return err - } - return nil + return pprof.WriteHeapProfile(output) +} + +// GoroutineProfileOpts contains options specifically for goroutine profiles. +type GoroutineProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload } -// GoroutineProfile is an RPC stub which dumps out the stack trace for all -// running goroutines. -func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error { +// Goroutine dumps out the stack trace for all running goroutines. +func (p *Profile) Goroutine(o *GoroutineProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("goroutine").WriteTo(output, 2); err != nil { - return err - } - return nil + + return pprof.Lookup("goroutine").WriteTo(output, 2) +} + +// BlockProfileOpts contains options specifically for block profiles. +type BlockProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` + + // Rate is the block profile rate. + Rate int `json:"rate"` } -// BlockProfile is an RPC stub which dumps out the stack trace that led to -// blocking on synchronization primitives. -func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error { +// Block dumps a blocking profile. +func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("block").WriteTo(output, 0); err != nil { - return err + + p.blockMu.Lock() + defer p.blockMu.Unlock() + + // Always set the rate. We then wait to collect a profile at this rate, + // and disable when we're done. + rate := 1 + if o.Rate != 0 { + rate = o.Rate } - return nil + runtime.SetBlockProfileRate(rate) + defer runtime.SetBlockProfileRate(0) + + // Collect the profile. + select { + case <-time.After(o.Duration): + case <-p.done: + } + + return pprof.Lookup("block").WriteTo(output, 0) +} + +// MutexProfileOpts contains options specifically for mutex profiles. +type MutexProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` + + // Fraction is the mutex profile fraction. + Fraction int `json:"fraction"` } -// MutexProfile is an RPC stub which dumps out the stack trace of holders of -// contended mutexes. -func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error { +// Mutex dumps a mutex profile. +func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil { - return err + + p.mutexMu.Lock() + defer p.mutexMu.Unlock() + + // Always set the fraction. + fraction := 1 + if o.Fraction != 0 { + fraction = o.Fraction } - return nil + runtime.SetMutexProfileFraction(fraction) + defer runtime.SetMutexProfileFraction(0) + + // Collect the profile. + select { + case <-time.After(o.Duration): + case <-p.done: + } + + return pprof.Lookup("mutex").WriteTo(output, 0) } -// StartTrace is an RPC stub which starts collection of an execution trace. -func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error { +// TraceProfileOpts contains options specifically for traces. +type TraceProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` +} + +// Trace is an RPC stub which starts collection of an execution trace. +func (p *Profile) Trace(o *TraceProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } output, err := fd.NewFromFile(o.FilePayload.Files[0]) if err != nil { return err } + defer output.Close() - p.mu.Lock() - defer p.mu.Unlock() + p.traceMu.Lock() + defer p.traceMu.Unlock() // Returns an error if profiling is already started. if err := trace.Start(output); err != nil { output.Close() return err } + defer trace.Stop() // Ensure all trace contexts are registered. - p.Kernel.RebuildTraceContexts() - - p.traceFile = output - return nil -} - -// StopTrace is an RPC stub which stops collection of an ongoing execution -// trace and flushes the trace data. It takes no argument. -func (p *Profile) StopTrace(_, _ *struct{}) error { - p.mu.Lock() - defer p.mu.Unlock() + p.kernel.RebuildTraceContexts() - if p.traceFile == nil { - return errors.New("Execution tracing not started") + // Wait for the trace. + select { + case <-time.After(o.Duration): + case <-p.done: } // Similarly to the case above, if tasks have not ended traces, we will // lose information. Thus we need to rebuild the tasks in order to have // complete information. This will not lose information if multiple // traces are overlapping. - p.Kernel.RebuildTraceContexts() + p.kernel.RebuildTraceContexts() - trace.Stop() - p.traceFile.Close() - p.traceFile = nil return nil } diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index d800f2c85..62eaca965 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { Callback: func(err error) { if err == nil { log.Infof("Save succeeded: exiting...") + s.Kernel.SetSaveSuccess(false /* autosave */) } else { log.Warningf("Save failed: exiting...") s.Kernel.SetSaveError(err) diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index 314661475..badd5b073 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package fdimport provides the Import function. package fdimport import ( diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index ff2fe6712..8e0aa9019 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err // copyUpBuffers is a buffer pool for copying file content. The buffer // size is the same used by io.Copy. -var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }} +var copyUpBuffers = sync.Pool{ + New: func() interface{} { + b := make([]byte, 8*usermem.PageSize) + return &b + }, +} // copyContentsLocked copies the contents of lower to upper. It panics if // less than size bytes can be copied. @@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in defer lowerFile.DecRef(ctx) // Use a buffer pool to minimize allocations. - buf := copyUpBuffers.Get().([]byte) + buf := copyUpBuffers.Get().(*[]byte) defer copyUpBuffers.Put(buf) // Transfer the contents. @@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in // optimizations could be self-defeating. So we leave this as simple as possible. var offset int64 for { - nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset) + nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset) if err != nil && err != io.EOF { return err } @@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in } return nil } - nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset) + nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset) if err != nil { return err } diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index c7a11eec1..e04784db2 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) { wg.Add(1) go func(o *overlayTestFile) { if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil { - t.Fatalf("failed to copy up: %v", err) + t.Errorf("failed to copy up: %v", err) } wg.Done() }(file) diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go index 8049538f2..ec3d3f96c 100644 --- a/pkg/sentry/fs/filetest/filetest.go +++ b/pkg/sentry/fs/filetest/filetest.go @@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File { // Read just fails the request. func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Readv not implemented") + return 0, fmt.Errorf("TestFileOperations.Read not implemented") } // Write just fails the request. func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Writev not implemented") + return 0, fmt.Errorf("TestFileOperations.Write not implemented") } diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go index d481baf77..e5579095b 100644 --- a/pkg/sentry/fs/gofer/attr.go +++ b/pkg/sentry/fs/gofer/attr.go @@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType { return fs.BlockDevice case pattr.Mode.IsSocket(): return fs.Socket - case pattr.Mode.IsRegular(): - fallthrough default: return fs.RegularFile } diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 9d6fdd08f..e840b6f5e 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -475,6 +475,9 @@ func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermM func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { switch d.Inode.StableAttr.Type { case fs.Socket: + if i.session().overrides != nil { + return nil, syserror.ENXIO + } return i.getFileSocket(ctx, d, flags) case fs.Pipe: return i.getFilePipe(ctx, d, flags) diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index fbfba1b58..2c14aa6d9 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -276,6 +276,10 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport. // GetFile implements fs.InodeOperations.GetFile. func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + if fs.IsSocket(d.Inode.StableAttr) { + return nil, syserror.ENXIO + } + return newFile(ctx, d, flags, i), nil } diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go index 29ff004f2..d0c565879 100644 --- a/pkg/sentry/fs/ramfs/socket.go +++ b/pkg/sentry/fs/ramfs/socket.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -63,7 +64,7 @@ func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint { // GetFile implements fs.FileOperations.GetFile. func (s *Socket) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, dirent, flags, &socketFileOperations{}), nil + return nil, syserror.ENXIO } // +stateify savable diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index e04cd608d..ad4aea282 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -148,6 +148,10 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare // GetFile implements fs.InodeOperations.GetFile. func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + if fs.IsSocket(d.Inode.StableAttr) { + return nil, syserror.ENXIO + } + if flags.Write { fsmetric.TmpfsOpensW.Increment() } else if flags.Read { diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 9009ba3c7..4a555bf72 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -200,7 +200,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt } var fd symlinkFD fd.LockFD.Init(&in.locks) - fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) + if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } return &fd.vfsfd, nil default: panic(fmt.Sprintf("unknown inode type: %T", in.impl)) diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 3af807a21..204d8d143 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -129,6 +129,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, syserror.EINVAL } fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + if fuseFDGeneric == nil { + return nil, nil, syserror.EINVAL + } defer fuseFDGeneric.DecRef(ctx) fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD) if !ok { diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index dc0180812..41d679358 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) src = src[2:] } + _ = src // Remove unused warning. } // SizeBytes is the size of the payload of the FUSE_INIT response. diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 435a21d77..36a3f6810 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -31,6 +31,7 @@ import ( fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -499,6 +500,10 @@ func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flag fileDescription: fileDescription{inode: i}, termios: linux.DefaultReplicaTermios, } + if task := kernel.TaskFromContext(ctx); task != nil { + fd.fgProcessGroup = task.ThreadGroup().ProcessGroup() + fd.session = fd.fgProcessGroup.Session() + } fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 0ecb592cf..429733c10 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -164,11 +164,11 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // and write ends of a newly-created pipe, as for pipe(2) and pipe2(2). // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). -func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) { +func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) { fs := mnt.Filesystem().Impl().(*filesystem) inode := newInode(ctx, fs) var d kernfs.Dentry d.Init(&fs.Filesystem, inode) defer d.DecRef(ctx) - return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags) + return inode.pipe.ReaderWriterPair(ctx, mnt, d.VFSDentry(), flags) } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index a3780b222..75be6129f 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -57,9 +57,6 @@ func getMM(task *kernel.Task) *mm.MemoryManager { // MemoryManager's users count is incremented, and must be decremented by the // caller when it is no longer in use. func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) { - if task.ExitState() == kernel.TaskExitDead { - return nil, syserror.ESRCH - } var m *mm.MemoryManager task.WithMuLocked(func(t *kernel.Task) { m = t.MemoryManager() @@ -111,9 +108,13 @@ var _ dynamicInode = (*auxvData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if d.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } m, err := getMMIncRef(d.task) if err != nil { - return err + // Return empty file. + return nil } defer m.DecUsers(ctx) @@ -157,9 +158,13 @@ var _ dynamicInode = (*cmdlineData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if d.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } m, err := getMMIncRef(d.task) if err != nil { - return err + // Return empty file. + return nil } defer m.DecUsers(ctx) @@ -472,7 +477,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64 } m, err := getMMIncRef(fd.inode.task) if err != nil { - return 0, nil + return 0, err } defer m.DecUsers(ctx) // Buffer the read data because of MM locks diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 10f1452ef..246bd87bc 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package signalfd provides basic signalfd file implementations. package signalfd import ( @@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 04e7110a3..a4ad625bb 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -163,7 +163,7 @@ afterSymlink: // verifyChildLocked verifies the hash of child against the already verified // hash of the parent to ensure the child is expected. verifyChild triggers a // sentry panic if unexpected modifications to the file system are detected. In -// noCrashOnVerificationFailure mode it returns a syserror instead. +// ErrorOnViolation mode it returns a syserror instead. // // Preconditions: // * fs.renameMu must be locked. diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 5788c661f..a5171b5ad 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -64,6 +64,10 @@ const ( // tree file for "/foo" is "/.merkle.verity.foo". merklePrefix = ".merkle.verity." + // merkleRootPrefix is the prefix of the Merkle tree root file. This + // needs to be different from merklePrefix to avoid name collision. + merkleRootPrefix = ".merkleroot.verity." + // merkleOffsetInParentXattr is the extended attribute name specifying the // offset of the child hash in its parent's Merkle tree. merkleOffsetInParentXattr = "user.merkle.offset" @@ -88,10 +92,8 @@ const ( ) var ( - // noCrashOnVerificationFailure indicates whether the sandbox should panic - // whenever verification fails. If true, an error is returned instead of - // panicking. This should only be set for tests. - noCrashOnVerificationFailure bool + // action specifies the action towards detected violation. + action ViolationAction // verityMu synchronizes concurrent operations that enable verity and perform // verification checks. @@ -102,6 +104,18 @@ var ( // content. type HashAlgorithm int +// ViolationAction is a type specifying the action when an integrity violation +// is detected. +type ViolationAction int + +const ( + // PanicOnViolation terminates the sentry on detected violation. + PanicOnViolation ViolationAction = 0 + // ErrorOnViolation returns an error from the violating system call on + // detected violation. + ErrorOnViolation = 1 +) + // Currently supported hashing algorithms include SHA256 and SHA512. const ( SHA256 HashAlgorithm = iota @@ -166,7 +180,7 @@ type filesystem struct { // its children. So they shouldn't be enabled the same time. This lock // is for the whole file system to ensure that no more than one file is // enabled the same time. - verityMu sync.RWMutex + verityMu sync.RWMutex `state:"nosave"` } // InternalFilesystemOptions may be passed as @@ -196,10 +210,8 @@ type InternalFilesystemOptions struct { // system wrapped by verity file system. LowerGetFSOptions vfs.GetFilesystemOptions - // NoCrashOnVerificationFailure indicates whether the sandbox should - // panic whenever verification fails. If true, an error is returned - // instead of panicking. This should only be set for tests. - NoCrashOnVerificationFailure bool + // Action specifies the action on an integrity violation. + Action ViolationAction } // Name implements vfs.FilesystemType.Name. @@ -211,10 +223,10 @@ func (FilesystemType) Name() string { func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means -// unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +// unexpected modification to the file system is detected. In ErrorOnViolation +// mode, it returns EIO, otherwise it panic. func alertIntegrityViolation(msg string) error { - if noCrashOnVerificationFailure { + if action == ErrorOnViolation { return syserror.EIO } panic(msg) @@ -227,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } - noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure + action = iopts.Action // Mount the lower file system. The lower file system is wrapped inside // verity, and should not be exposed or connected. @@ -255,7 +267,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merklePrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -744,7 +756,7 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) // file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The // hash of the generated Merkle tree and the data size is returned. If fd // points to a regular file, the data is the content of the file. If fd points -// to a directory, the data is all hahes of its children, written to the Merkle +// to a directory, the data is all hashes of its children, written to the Merkle // tree file. // // Preconditions: fd.d.fs.verityMu must be locked. diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 6ced0afc9..30d8b4355 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -92,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, - LowerName: "tmpfs", - Alg: hashAlg, - AllowRuntimeEnable: true, - NoCrashOnVerificationFailure: true, + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + Alg: hashAlg, + AllowRuntimeEnable: true, + Action: ErrorOnViolation, }, }, }) @@ -239,6 +239,18 @@ func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, return fd, dataSize, err } +// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD. +func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) { + // Create the file in the underlying file system. + _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) + if err != nil { + return nil, err + } + // Now open the verity file descriptor. + fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) + return fd, err +} + // flipRandomBit randomly flips a bit in the file represented by fd. func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { randomPos := int64(rand.Intn(size)) @@ -349,6 +361,36 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } } +// TestReadUnmodifiedEmptyFileSucceeds ensures that read from an untouched empty verity +// file succeeds after enabling verity for it. +func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-empty-file" + fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newEmptyFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + enableVerity(ctx, t, fd) + + var buf []byte + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != 0 { + t.Errorf("fd.Read got read length %d, expected 0", n) + } + } +} + // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 2cdcdfc1f..b8627a54f 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -214,9 +214,11 @@ type Kernel struct { // netlinkPorts manages allocation of netlink socket port IDs. netlinkPorts *port.Manager - // saveErr is the error causing the sandbox to exit during save, if - // any. It is protected by extMu. - saveErr error `state:"nosave"` + // saveStatus is nil if the sandbox has not been saved, errSaved or + // errAutoSaved if it has been saved successfully, or the error causing the + // sandbox to exit during save. + // It is protected by extMu. + saveStatus error `state:"nosave"` // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` @@ -1481,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager { return k.netlinkPorts } -// SaveError returns the sandbox error that caused the kernel to exit during -// save. -func (k *Kernel) SaveError() error { +var ( + errSaved = errors.New("sandbox has been successfully saved") + errAutoSaved = errors.New("sandbox has been successfully auto-saved") +) + +// SaveStatus returns the sandbox save status. If it was saved successfully, +// autosaved indicates whether save was triggered by autosave. If it was not +// saved successfully, err indicates the sandbox error that caused the kernel to +// exit during save. +func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) { + k.extMu.Lock() + defer k.extMu.Unlock() + switch k.saveStatus { + case nil: + return false, false, nil + case errSaved: + return true, false, nil + case errAutoSaved: + return true, true, nil + default: + return false, false, k.saveStatus + } +} + +// SetSaveSuccess sets the flag indicating that save completed successfully, if +// no status was already set. +func (k *Kernel) SetSaveSuccess(autosave bool) { k.extMu.Lock() defer k.extMu.Unlock() - return k.saveErr + if k.saveStatus == nil { + if autosave { + k.saveStatus = errAutoSaved + } else { + k.saveStatus = errSaved + } + } } // SetSaveError sets the sandbox error that caused the kernel to exit during @@ -1494,8 +1526,8 @@ func (k *Kernel) SaveError() error { func (k *Kernel) SetSaveError(err error) { k.extMu.Lock() defer k.extMu.Unlock() - if k.saveErr == nil { - k.saveErr = err + if k.saveStatus == nil { + k.saveStatus = err } } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 7b23cbe86..2d47d2e82 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -63,10 +63,19 @@ func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe { // ReaderWriterPair returns read-only and write-only FDs for vp. // // Preconditions: statusFlags should not contain an open access mode. -func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) { +func (vp *VFSPipe) ReaderWriterPair(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) { // Connected pipes share the same locks. locks := &vfs.FileLocks{} - return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) + r, err := vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks) + if err != nil { + return nil, nil, err + } + w, err := vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) + if err != nil { + r.DecRef(ctx) + return nil, nil, err + } + return r, w, nil } // Allocate implements vfs.FileDescriptionImpl.Allocate. @@ -85,7 +94,10 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s return nil, syserror.EINVAL } - fd := vp.newFD(mnt, vfsd, statusFlags, locks) + fd, err := vp.newFD(mnt, vfsd, statusFlags, locks) + if err != nil { + return nil, err + } // Named pipes have special blocking semantics during open: // @@ -137,16 +149,18 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s } // Preconditions: vp.mu must be held. -func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription { +func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { fd := &VFSPipeFD{ pipe: &vp.pipe, } fd.LockFD.Init(locks) - fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{ + if err := fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{ DenyPRead: true, DenyPWrite: true, UseDentryMetadata: true, - }) + }); err != nil { + return nil, err + } switch { case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable(): @@ -160,7 +174,7 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l panic("invalid pipe flags: must be readable, writable, or both") } - return &fd.vfsfd + return &fd.vfsfd, nil } // VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 1abfe2201..cef58a590 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) { Signo: int32(linux.SIGTRAP), Code: code, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) if t.beginPtraceStopLocked() { tracer := t.Tracer() tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP)) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 31198d772..db01e4a97 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -52,6 +52,9 @@ type Registry struct { mu sync.Mutex `state:"nosave"` semaphores map[int32]*Set lastIDUsed int32 + // indexes maintains a mapping between a set's index in virtual array and + // its identifier. + indexes map[int32]int32 } // Set represents a set of semaphores that can be operated atomically. @@ -113,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { return &Registry{ userNS: userNS, semaphores: make(map[int32]*Set), + indexes: make(map[int32]int32), } } @@ -163,6 +167,9 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu } // Apply system limits. + // + // Map semaphores and map indexes in a registry are of the same size, + // check map semaphores only here for the system limit. if len(r.semaphores) >= setsMax { return nil, syserror.EINVAL } @@ -186,12 +193,43 @@ func (r *Registry) IPCInfo() *linux.SemInfo { SemMsl: linux.SEMMSL, SemOpm: linux.SEMOPM, SemUme: linux.SEMUME, - SemUsz: 0, // SemUsz not supported. + SemUsz: linux.SEMUSZ, SemVmx: linux.SEMVMX, SemAem: linux.SEMAEM, } } +// SemInfo returns a seminfo structure containing the same information as +// for IPC_INFO, except that SemUsz field returns the number of existing +// semaphore sets, and SemAem field returns the number of existing semaphores. +func (r *Registry) SemInfo() *linux.SemInfo { + r.mu.Lock() + defer r.mu.Unlock() + + info := r.IPCInfo() + info.SemUsz = uint32(len(r.semaphores)) + info.SemAem = uint32(r.totalSems()) + + return info +} + +// HighestIndex returns the index of the highest used entry in +// the kernel's array. +func (r *Registry) HighestIndex() int32 { + r.mu.Lock() + defer r.mu.Unlock() + + // By default, highest used index is 0 even though + // there is no semaphroe set. + var highestIndex int32 + for index := range r.indexes { + if index > highestIndex { + highestIndex = index + } + } + return highestIndex +} + // RemoveID removes set with give 'id' from the registry and marks the set as // dead. All waiters will be awakened and fail. func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { @@ -202,6 +240,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { if set == nil { return syserror.EINVAL } + index, found := r.findIndexByID(id) + if !found { + // Inconsistent state. + panic(fmt.Sprintf("unable to find an index for ID: %d", id)) + } set.mu.Lock() defer set.mu.Unlock() @@ -213,6 +256,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { } delete(r.semaphores, set.ID) + delete(r.indexes, index) set.destroy() return nil } @@ -236,6 +280,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File continue } if r.semaphores[id] == nil { + index, found := r.findFirstAvailableIndex() + if !found { + panic("unable to find an available index") + } + r.indexes[index] = id r.lastIDUsed = id r.semaphores[id] = set set.ID = id @@ -254,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set { return r.semaphores[id] } +// FindByIndex looks up a set given an index. +func (r *Registry) FindByIndex(index int32) *Set { + r.mu.Lock() + defer r.mu.Unlock() + + id, present := r.indexes[index] + if !present { + return nil + } + return r.semaphores[id] +} + func (r *Registry) findByKey(key int32) *Set { for _, v := range r.semaphores { if v.key == key { @@ -263,6 +324,24 @@ func (r *Registry) findByKey(key int32) *Set { return nil } +func (r *Registry) findIndexByID(id int32) (int32, bool) { + for k, v := range r.indexes { + if v == id { + return k, true + } + } + return 0, false +} + +func (r *Registry) findFirstAvailableIndex() (int32, bool) { + for index := int32(0); index < setsMax; index++ { + if _, present := r.indexes[index]; !present { + return index, true + } + } + return 0, false +} + func (r *Registry) totalSems() int { totalSems := 0 for _, v := range r.semaphores { diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 80a592c8f..073e14507 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -6,6 +6,9 @@ package(licenses = ["notice"]) go_template_instance( name = "shm_refs", out = "shm_refs.go", + consts = { + "enableLogging": "true", + }, package = "shm", prefix = "Shm", template = "//pkg/refsvfs2:refs_template", diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go index e8cce37d0..2488ae7d5 100644 --- a/pkg/sentry/kernel/signal.go +++ b/pkg/sentry/kernel/signal.go @@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 78f718cfe..884966120 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c5137c282..16986244c 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -368,8 +368,8 @@ func (t *Task) exitChildren() { Signo: int32(sig), Code: arch.SignalInfoUser, } - siginfo.SetPid(int32(c.tg.pidns.tids[t])) - siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) + siginfo.SetPID(int32(c.tg.pidns.tids[t])) + siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) c.tg.signalHandlers.mu.Lock() c.sendSignalLocked(siginfo, true /* group */) c.tg.signalHandlers.mu.Unlock() @@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si info := &arch.SignalInfo{ Signo: int32(sig), } - info.SetPid(int32(receiver.tg.pidns.tids[t])) - info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.tids[t])) + info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) if t.exitStatus.Signaled() { info.Code = arch.CLD_KILLED info.SetStatus(int32(t.exitStatus.Signo)) diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 42dd3e278..75af3af79 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -914,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { Signo: int32(linux.SIGCHLD), Code: code, } - sigchld.SetPid(int32(t.tg.pidns.tids[target])) - sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + sigchld.SetPID(int32(t.tg.pidns.tids[target])) + sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) sigchld.SetStatus(status) // TODO(b/72102453): Set utime, stime. t.sendSignalLocked(sigchld, true /* group */) @@ -1022,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState { Signo: int32(sig), Code: t.ptraceCode, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } else { t.ptraceCode = int32(sig) t.ptraceSiginfo = nil @@ -1114,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { if parent == nil { // Tracer has detached and t was created by Kernel.CreateProcess(). // Pretend the parent is in an ancestor PID + user namespace. - info.SetPid(0) - info.SetUid(int32(auth.OverflowUID)) + info.SetPID(0) + info.SetUID(int32(auth.OverflowUID)) } else { - info.SetPid(int32(t.tg.pidns.tids[parent])) - info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + info.SetPID(int32(t.tg.pidns.tids[parent])) + info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } } t.tg.signalHandlers.mu.Lock() diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 7fd77925f..49e21026e 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp // Translations must be contiguous and in increasing order of // Translation.Source. if i > 0 && ts[i-1].Source.End != t.Source.Start { - return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t) + return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t) } // At least part of each Translation must be required. if t.Source.Intersect(required).Length() == 0 { diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 4c8cd38ed..5ab2ef79f 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -36,12 +36,12 @@ type aioManager struct { contexts map[uint64]*AIOContext } -func (a *aioManager) destroy() { - a.mu.Lock() - defer a.mu.Unlock() +func (mm *MemoryManager) destroyAIOManager(ctx context.Context) { + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() - for _, ctx := range a.contexts { - ctx.destroy() + for id := range mm.aioManager.contexts { + mm.destroyAIOContextLocked(ctx, id) } } @@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool { // be drained. // // Nil is returned if the context does not exist. -func (a *aioManager) destroyAIOContext(id uint64) *AIOContext { - a.mu.Lock() - defer a.mu.Unlock() - ctx, ok := a.contexts[id] +// +// Precondition: mm.aioManager.mu is locked. +func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext { + aioCtx, ok := mm.aioManager.contexts[id] if !ok { return nil } - delete(a.contexts, id) - ctx.destroy() - return ctx + + // Only unmaps after it assured that the address is a valid aio context to + // prevent random memory from been unmapped. + // + // Note: It's possible to unmap this address and map something else into + // the same address. Then it would be unmapping memory that it doesn't own. + // This is, however, the way Linux implements AIO. Keeps the same [weird] + // semantics in case anyone relies on it. + mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) + + delete(mm.aioManager.contexts, id) + aioCtx.destroy() + return aioCtx } // lookupAIOContext looks up the given context. @@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() { } } -// Prepare reserves space for a new request, returning true if available. -// Returns false if the context is busy. -func (ctx *AIOContext) Prepare() bool { +// Prepare reserves space for a new request, returning nil if available. +// Returns EAGAIN if the context is busy and EINVAL if the context is dead. +func (ctx *AIOContext) Prepare() error { ctx.mu.Lock() defer ctx.mu.Unlock() + if ctx.dead { + // Context died after the caller looked it up. + return syserror.EINVAL + } if ctx.outstanding >= ctx.maxOutstanding { - return false + // Context is busy. + return syserror.EAGAIN } ctx.outstanding++ - return true + return nil } // PopRequest pops a completed request if available, this function does not do @@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint // DestroyAIOContext destroys an asynchronous I/O context. It returns the // destroyed context. nil if the context does not exist. func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext { - if _, ok := mm.LookupAIOContext(ctx, id); !ok { + if !mm.isValidAddr(ctx, id) { return nil } - // Only unmaps after it assured that the address is a valid aio context to - // prevent random memory from been unmapped. - // - // Note: It's possible to unmap this address and map something else into - // the same address. Then it would be unmapping memory that it doesn't own. - // This is, however, the way Linux implements AIO. Keeps the same [weird] - // semantics in case anyone relies on it. - mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) - - return mm.aioManager.destroyAIOContext(id) + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() + return mm.destroyAIOContextLocked(ctx, id) } // LookupAIOContext looks up the given context. It returns false if the context @@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC return nil, false } - // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes - // from id). - var buf [4]byte - _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) - if err != nil { + // Protect against 'id' that is inaccessible. + if !mm.isValidAddr(ctx, id) { return nil, false } return aioCtx, true } + +// isValidAddr determines if the address `id` is valid. (Linux also reads 4 +// bytes from id). +func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool { + var buf [4]byte + _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) + return err == nil +} diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go index 3dabac1af..e8931922f 100644 --- a/pkg/sentry/mm/aio_context_state.go +++ b/pkg/sentry/mm/aio_context_state.go @@ -15,6 +15,6 @@ package mm // afterLoad is invoked by stateify. -func (a *AIOContext) afterLoad() { - a.requestReady = make(chan struct{}, 1) +func (ctx *AIOContext) afterLoad() { + ctx.requestReady = make(chan struct{}, 1) } diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 09dbc06a4..120707429 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) { panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users)) } - mm.aioManager.destroy() + mm.destroyAIOManager(ctx) mm.metadataMu.Lock() exe := mm.executable diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index acac3d357..bc53bd41e 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) { t.Errorf("CopyOut got %d want 1", n) } } + +// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be +// prepared after destruction. +func TestAIOPrepareAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + defer mm.DecUsers(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + aioCtx, ok := mm.LookupAIOContext(ctx, id) + if !ok { + t.Fatalf("AIOContext not found") + } + mm.DestroyAIOContext(ctx, id) + + // Prepare should fail because aioCtx should be destroyed. + if err := aioCtx.Prepare(); err != syserror.EINVAL { + t.Errorf("aioCtx.Prepare got err %v want nil", err) + } else if err == nil { + aioCtx.CancelPendingRequest() + } +} + +// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be +// looked up after memory manager is destroyed. +func TestAIOLookupAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + mm.DecUsers(ctx) + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + mm.DecUsers(ctx) // This destroys the AIOContext manager. + + if _, ok := mm.LookupAIOContext(ctx, id); ok { + t.Errorf("AIOContext found even after AIOContext manager is destroyed") + } +} diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 812ab80ef..aacd7ce70 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // facilitate vsyscall emulation. See patchSignalInfo. patchSignalInfo(regs, &c.signalInfo) return false - } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) { + } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) { // The signal was generated by this process. That means // that it was an interrupt or something else that we // should bail for. Note that we ignore signals diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 5d01d21dd..2852b7387 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "arch_genrule", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -39,19 +39,19 @@ go_template_instance( template = ":defs_arm64", ) -genrule( +arch_genrule( name = "entry_impl_amd64", srcs = ["entry_amd64.s"], outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) -genrule( +arch_genrule( name = "entry_impl_arm64", srcs = ["entry_arm64.s"], outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 9742308d8..a9703baf6 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,6 +24,9 @@ go_binary( "defs_impl_arm64.go", "main.go", ], + # Use the libc malloc to avoid any extra dependencies. This is required to + # pass the sentry deps test. + system_malloc = True, visibility = [ "//pkg/sentry/platform/kvm:__pkg__", "//pkg/sentry/platform/ring0:__pkg__", diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 90a7b8392..c05284641 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -53,11 +53,17 @@ func IsCanonical(addr uint64) bool { return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000 } +// SwitchToUser performs an eret. +// +// The return value is the exception vector. +// +// +checkescape:all +// //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeAppASID(uintptr(switchOpts.UserASID)) if switchOpts.Flush { - FlushTlbAll() + FlushTlbByASID(uintptr(switchOpts.UserASID)) } regs := switchOpts.Registers diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index ef0d8974d..a490bf3af 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -22,19 +22,25 @@ func storeAppASID(asid uintptr) // LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU. func LocalFlushTlbAll() -// FlushTlbAll flush all tlb. +// FlushTlbByVA invalidates tlb by VA/Last-level/Inner-Shareable. +func FlushTlbByVA(addr uintptr) + +// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable. +func FlushTlbByASID(asid uintptr) + +// FlushTlbAll invalidates all tlb. func FlushTlbAll() // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) -// FPCR returns the value of FPCR register. +// GetFPCR returns the value of FPCR register. func GetFPCR() (value uintptr) // SetFPCR writes the FPCR value. func SetFPCR(value uintptr) -// FPSR returns the value of FPSR register. +// GetFPSR returns the value of FPSR register. func GetFPSR() (value uintptr) // SetFPSR writes the FPSR value. diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 6f4923539..e39b32841 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,23 @@ #include "funcdata.h" #include "textflag.h" +#define TLBI_ASID_SHIFT 48 + +TEXT ·FlushTlbByVA(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + DSB $10 // dsb(ishst) + WORD $0xd50883a1 // tlbi vale1is, x1 + DSB $11 // dsb(ish) + RET + +TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + LSL $TLBI_ASID_SHIFT, R1, R1 + DSB $10 // dsb(ishst) + WORD $0xd5088341 // tlbi aside1is, x1 + DSB $11 // dsb(ish) + RET + TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 DSB $6 // dsb(nshst) WORD $0xd508871f // __tlbi(vmalle1) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index ca16d0381..fb7c5dc61 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -23,7 +23,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserror", - "//pkg/tcpip", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..ff6b71802 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" ) @@ -344,18 +343,42 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { } // PackIPPacketInfo packs an IP_PKTINFO socket control message. -func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { - var p linux.ControlMessageIPPacketInfo - p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) - +func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, linux.IP_PKTINFO, t.Arch().Width(), - p, + packetInfo, + ) +} + +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { + var level uint32 + var optType uint32 + switch originalDstAddress.(type) { + case *linux.SockAddrInet: + level = linux.SOL_IP + optType = linux.IP_RECVORIGDSTADDR + case *linux.SockAddrInet6: + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + default: + panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg") + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), originalDstAddress) +} + +// PackSockExtendedErr packs an IP*_RECVERR socket control message. +func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte { + return putCmsgStruct( + buf, + sockErr.CMsgLevel(), + sockErr.CMsgType(), + t.Arch().Width(), + sockErr, ) } @@ -384,7 +407,15 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt } if cmsgs.IP.HasIPPacketInfo { - buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) + } + + if cmsgs.IP.OriginalDstAddress != nil { + buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) + } + + if cmsgs.IP.SockErr != nil { + buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf) } return buf @@ -416,17 +447,19 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageTClass) } - return space -} + if cmsgs.IP.HasIPPacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) + } -// NewIPPacketInfo returns the IPPacketInfo struct. -func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { - var p tcpip.IPPacketInfo - p.NIC = tcpip.NICID(packetInfo.NIC) - copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) - copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + if cmsgs.IP.OriginalDstAddress != nil { + space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) + } - return p + if cmsgs.IP.SockErr != nil { + space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes()) + } + + return space } // Parse parses a raw socket control message into portable objects. @@ -489,6 +522,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.Unix.Credentials = scmCreds i += binary.AlignUp(length, width) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + return socket.ControlMessages{}, syserror.EINVAL + } + cmsgs.IP.HasTimestamp = true + binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp) + i += binary.AlignUp(length, width) + default: // Unknown message type. return socket.ControlMessages{}, syserror.EINVAL @@ -512,7 +553,26 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + cmsgs.IP.PacketInfo = packetInfo + i += binary.AlignUp(length, width) + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += binary.AlignUp(length, width) + + case linux.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg i += binary.AlignUp(length, width) default: @@ -528,6 +588,25 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) i += binary.AlignUp(length, width) + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += binary.AlignUp(length, width) + + case linux.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..5b868216d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 case linux.SO_LINGER: optlen = syscall.SizeofLinger @@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 case linux.IP_PKTINFO: optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 } case linux.SOL_TCP: switch name { - case linux.TCP_NODELAY: + case linux.TCP_NODELAY, linux.TCP_INQ: optlen = sizeofInt32 } } @@ -416,68 +416,76 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] return nil } -// RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - // Only allow known and safe flags. - // - // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the - // Socket interface's dependence on netstack. - if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { - return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument - } +func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) { + // We always do a non-blocking recv*(). + sysflags := flags | syscall.MSG_DONTWAIT - var senderAddr linux.SockAddr + msg := syscall.Msghdr{} + if len(iovs) > 0 { + msg.Iov = &iovs[0] + msg.Iovlen = uint64(len(iovs)) + } var senderAddrBuf []byte if senderRequested { senderAddrBuf = make([]byte, sizeofSockaddr) + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(sizeofSockaddr) } - var controlBuf []byte - var msgFlags int - - recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { - // Refuse to do anything if any part of dst.Addrs was unusable. - if uint64(dst.NumBytes()) != dsts.NumBytes() { - return 0, nil - } - if dsts.IsEmpty() { - return 0, nil + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } + n, err := recvmsg(s.fd, &msg, sysflags) + if err != nil { + return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err + } + return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err +} - // We always do a non-blocking recv*(). - sysflags := flags | syscall.MSG_DONTWAIT +// RecvMsg implements socket.Socket.RecvMsg. +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + // Only allow known and safe flags. + if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument + } - iovs := safemem.IovecsFromBlockSeq(dsts) - msg := syscall.Msghdr{ - Iov: &iovs[0], - Iovlen: uint64(len(iovs)), - } - if len(senderAddrBuf) != 0 { - msg.Name = &senderAddrBuf[0] - msg.Namelen = uint32(len(senderAddrBuf)) - } - if controlLen > 0 { - if controlLen > maxControlLen { - controlLen = maxControlLen + var senderAddrBuf []byte + var controlBuf []byte + var msgFlags int + copyToDst := func() (int64, error) { + var n uint64 + var err error + if dst.NumBytes() == 0 { + // We want to make the recvmsg(2) call to the host even if dst is empty + // to fetch control messages, sender address or errors if any occur. + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen) + return int64(n), err + } + + recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { + // Refuse to do anything if any part of dst.Addrs was unusable. + if uint64(dst.NumBytes()) != dsts.NumBytes() { + return 0, nil } - controlBuf = make([]byte, controlLen) - msg.Control = &controlBuf[0] - msg.Controllen = controlLen - } - n, err := recvmsg(s.fd, &msg, sysflags) - if err != nil { - return 0, err - } - senderAddrBuf = senderAddrBuf[:msg.Namelen] - msgFlags = int(msg.Flags) - controlLen = uint64(msg.Controllen) - return n, nil - }) + if dsts.IsEmpty() { + return 0, nil + } + + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen) + return n, err + }) + return dst.CopyOutFrom(t, recvmsgToBlocks) + } var ch chan struct{} - n, err := dst.CopyOutFrom(t, recvmsgToBlocks) - if flags&syscall.MSG_DONTWAIT == 0 { + n, err := copyToDst() + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. @@ -494,48 +502,85 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags s.EventRegister(&e, waiter.EventIn) defer s.EventUnregister(&e) } - n, err = dst.CopyOutFrom(t, recvmsgToBlocks) + n, err = copyToDst() } } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + var senderAddr linux.SockAddr if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil +} +func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages { controlMessages := socket.ControlMessages{} for _, unixCmsg := range unixControlMessages { switch unixCmsg.Header.Level { - case syscall.SOL_IP: + case linux.SOL_SOCKET: switch unixCmsg.Header.Type { - case syscall.IP_TOS: + case linux.SO_TIMESTAMP: + controlMessages.IP.HasTimestamp = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp) + } + + case linux.SOL_IP: + switch unixCmsg.Header.Type { + case linux.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) - case syscall.IP_PKTINFO: + case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) + controlMessages.IP.PacketInfo = packetInfo + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg } - case syscall.SOL_IPV6: + case linux.SOL_IPV6: switch unixCmsg.Header.Type { - case syscall.IPV6_TCLASS: + case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg + } + + case linux.SOL_TCP: + switch unixCmsg.Header.Type { + case linux.TCP_INQ: + controlMessages.IP.HasInq = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq) } } } - - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil + return controlMessages } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 3cc0d4f0f..3f587638f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -320,7 +320,7 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages + readCM socket.IPControlMessages sender tcpip.FullAddress linkPacketInfo tcpip.LinkPacketInfo @@ -408,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = v - s.readCM = cms + s.readCM = socket.NewIPControlMessages(s.family, cms) atomic.StoreUint32(&s.readViewHasData, 1) return nil @@ -428,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { return } - var v tcpip.LingerOption - if err := s.Endpoint.GetSockOpt(&v); err != nil { - return - } - + v := s.Endpoint.SocketOptions().GetLinger() // The case for zero timeout is handled in tcp endpoint close function. // Close is blocked until either: // 1. The endpoint state is not in any of the states: FIN-WAIT1, @@ -965,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. - err := ep.LastError() + err := ep.SocketOptions().GetLastError() if err == nil { optP := primitive.Int32(0) return &optP, nil @@ -1046,10 +1042,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return &v, nil case linux.SO_BINDTODEVICE: - var v tcpip.BindToDeviceOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetBindToDevice() if v == 0 { var b primitive.ByteSlice return &b, nil @@ -1092,11 +1085,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.LingerOption var linger linux.Linger - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetLinger() if v.Enabled { linger.OnOff = 1 @@ -1127,13 +1117,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.OutOfBandInlineOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(v) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline())) + return &v, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { @@ -1417,6 +1402,21 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) return &v, nil + case linux.IPV6_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil + + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1583,6 +1583,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) return &v, nil + case linux.IP_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil + case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) return &v, nil + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil + case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { return nil, syserr.ErrInvalidArgument @@ -1785,8 +1801,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } name := string(optVal[:n]) if name == "" { - v := tcpip.BindToDeviceOption(0) - return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) + return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0)) } s := t.NetworkContext() if s == nil { @@ -1794,8 +1809,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } for nicID, nic := range s.Interfaces() { if nic.Name == name { - v := tcpip.BindToDeviceOption(nicID) - return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) + return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID)) } } return syserr.ErrUnknownDevice @@ -1864,8 +1878,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - opt := tcpip.OutOfBandInlineOption(v) - return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) + ep.SocketOptions().SetOutOfBandInline(v != 0) + return nil case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1888,10 +1902,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError( - ep.SetSockOpt(&tcpip.LingerOption{ - Enabled: v.OnOff != 0, - Timeout: time.Second * time.Duration(v.Linger)})) + ep.SocketOptions().SetLinger(tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger), + }) + return nil case linux.SO_DETACH_FILTER: // optval is ignored. @@ -2094,6 +2109,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2115,6 +2139,16 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name ep.SocketOptions().SetReceiveTClass(v != 0) return nil + case linux.IPV6_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2303,6 +2337,17 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetReceiveTOS(v != 0) return nil + case linux.IP_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil + case linux.IP_PKTINFO: if len(optVal) == 0 { return nil @@ -2325,6 +2370,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetHeaderIncluded(v != 0) return nil + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { return syserr.ErrInvalidArgument @@ -2360,10 +2417,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2437,11 +2492,9 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_MULTICAST_IF, linux.IPV6_MULTICAST_LOOP, linux.IPV6_RECVDSTOPTS, - linux.IPV6_RECVERR, linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2472,7 +2525,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { linux.IP_PKTINFO, linux.IP_PKTOPTIONS, linux.IP_MTU_DISCOVER, - linux.IP_RECVERR, linux.IP_RECVTTL, linux.IP_RECVTOS, linux.IP_MTU, @@ -2701,7 +2753,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // We need to peek beyond the first message. dst = dst.DropFirst(n) num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) { - n, _, err := s.Endpoint.Peek(dsts) + n, err := s.Endpoint.Peek(dsts) // TODO(b/78348848): Handle peek timestamp. if err != nil { return int64(n), syserr.TranslateNetstackError(err).ToError() @@ -2745,15 +2797,19 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ - IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + IP: socket.IPControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasInq: s.readCM.HasInq, + Inq: s.readCM.Inq, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + OriginalDstAddress: s.readCM.OriginalDstAddress, + SockErr: s.readCM.SockErr, }, } } @@ -2770,9 +2826,66 @@ func (s *socketOpsCommon) updateTimestamp() { } } +// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb(). +func (s *socketOpsCommon) dequeueErr() *tcpip.SockError { + so := s.Endpoint.SocketOptions() + err := so.DequeueErr() + if err == nil { + return nil + } + + // Update socket error to reflect ICMP errors in queue. + if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() { + so.SetLastError(nextErr.Err) + } else if err.ErrOrigin.IsICMPErr() { + so.SetLastError(nil) + } + return err +} + +// addrFamilyFromNetProto returns the address family identifier for the given +// network protocol. +func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { + switch net { + case header.IPv4ProtocolNumber: + return linux.AF_INET + case header.IPv6ProtocolNumber: + return linux.AF_INET6 + default: + panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net)) + } +} + +// recvErr handles MSG_ERRQUEUE for recvmsg(2). +// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error(). +func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + sockErr := s.dequeueErr() + if sockErr == nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + } + + // The payload of the original packet that caused the error is passed as + // normal data via msg_iovec. -- recvmsg(2) + msgFlags := linux.MSG_ERRQUEUE + if int(dst.NumBytes()) < len(sockErr.Payload) { + msgFlags |= linux.MSG_TRUNC + } + n, err := dst.CopyOut(t, sockErr.Payload) + + // The original destination address of the datagram that caused the error is + // supplied via msg_name. -- recvmsg(2) + dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst) + cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})} + return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err) +} + // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { + if flags&linux.MSG_ERRQUEUE != 0 { + return s.recvErr(t, dst) + } + trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 9049e8a21..97729dacc 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -44,7 +44,134 @@ import ( // control messages. type ControlMessages struct { Unix transport.ControlMessages - IP tcpip.ControlMessages + IP IPControlMessages +} + +// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format. +func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + return p +} + +// errOriginToLinux maps tcpip socket origin to Linux socket origin constants. +func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 { + switch origin { + case tcpip.SockExtErrorOriginNone: + return linux.SO_EE_ORIGIN_NONE + case tcpip.SockExtErrorOriginLocal: + return linux.SO_EE_ORIGIN_LOCAL + case tcpip.SockExtErrorOriginICMP: + return linux.SO_EE_ORIGIN_ICMP + case tcpip.SockExtErrorOriginICMP6: + return linux.SO_EE_ORIGIN_ICMP6 + default: + panic(fmt.Sprintf("unknown socket origin: %d", origin)) + } +} + +// sockErrCmsgToLinux converts SockError control message from tcpip format to +// Linux format. +func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { + if sockErr == nil { + return nil + } + + ee := linux.SockExtendedErr{ + Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), + Origin: errOriginToLinux(sockErr.ErrOrigin), + Type: sockErr.ErrType, + Code: sockErr.ErrCode, + Info: sockErr.ErrInfo, + } + + switch sockErr.NetProto { + case header.IPv4ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet) + } + return errMsg + case header.IPv6ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet6) + } + return errMsg + default: + panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto)) + } +} + +// NewIPControlMessages converts the tcpip ControlMessgaes (which does not +// have Linux specific format) to Linux format. +func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages { + var orgDstAddr linux.SockAddr + if cmgs.HasOriginalDstAddress { + orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) + } + return IPControlMessages{ + HasTimestamp: cmgs.HasTimestamp, + Timestamp: cmgs.Timestamp, + HasInq: cmgs.HasInq, + Inq: cmgs.Inq, + HasTOS: cmgs.HasTOS, + TOS: cmgs.TOS, + HasTClass: cmgs.HasTClass, + TClass: cmgs.TClass, + HasIPPacketInfo: cmgs.HasIPPacketInfo, + PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + OriginalDstAddress: orgDstAddr, + SockErr: sockErrCmsgToLinux(cmgs.SockErr), + } +} + +// IPControlMessages contains socket control messages for IP sockets. +// This can contain Linux specific structures unlike tcpip.ControlMessages. +// +// +stateify savable +type IPControlMessages struct { + // HasTimestamp indicates whether Timestamp is valid/set. + HasTimestamp bool + + // Timestamp is the time (in ns) that the last packet used to create + // the read data was received. + Timestamp int64 + + // HasInq indicates whether Inq is valid/set. + HasInq bool + + // Inq is the number of bytes ready to be received. + Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS uint8 + + // HasTClass indicates whether TClass is valid/set. + HasTClass bool + + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo linux.ControlMessageIPPacketInfo + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress linux.SockAddr + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr linux.SockErrCMsg } // Release releases Unix domain socket credentials and rights. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0247e93fa..099a56281 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -746,9 +746,6 @@ type baseEndpoint struct { // or may be used if the endpoint is connected. path string - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -840,12 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - e.linger = *v - e.Unlock() - } return nil } @@ -922,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - *o = e.linger - e.Unlock() - return nil - - default: - log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption - } + log.Warningf("Unsupported socket option: %T", opt) + return tcpip.ErrUnknownProtocolOption } // LastError implements Endpoint.LastError. diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index cff442846..b815e498f 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index 0bf313a13..c2285f796 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := ctx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := ctx.Prepare(); err != nil { + return err } if eventFile != nil { diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 8db587401..c33571f43 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -175,6 +175,12 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint } } + file, err := d.Inode.GetFile(t, d, fileFlags) + if err != nil { + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) + } + defer file.DecRef(t) + // Truncate is called when O_TRUNC is specified for any kind of // existing Dirent. Behavior is delegated to the entry's Truncate // implementation. @@ -184,12 +190,6 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint } } - file, err := d.Inode.GetFile(t, d, fileFlags) - if err != nil { - return syserror.ConvertIntr(err, syserror.ERESTARTSYS) - } - defer file.DecRef(t) - // Success. newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index a1601676f..1166cd7bb 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -150,14 +150,33 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal buf := args[3].Pointer() r := t.IPCNamespace().SemaphoreRegistry() info := r.IPCInfo() - _, err := info.CopyOut(t, buf) - // TODO(gvisor.dev/issue/137): Return the index of the highest used entry. - return 0, nil, err + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil + + case linux.SEM_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.SemInfo() + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil - case linux.SEM_INFO, - linux.SEM_STAT, - linux.SEM_STAT_ANY: + case linux.SEM_STAT: + arg := args[3].Pointer() + // id is an index in SEM_STAT. + semid, ds, err := semStat(t, id) + if err != nil { + return 0, nil, err + } + if _, err := ds.CopyOut(t, arg); err != nil { + return 0, nil, err + } + return uintptr(semid), nil, err + case linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -202,6 +221,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { return set.GetStat(creds) } +func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByIndex(index) + if set == nil { + return 0, nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + ds, err := set.GetStat(creds) + return set.ID, ds, err +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index e748d33d8..d639c9bf7 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(target.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) + info.SetPID(int32(target.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) if err := target.SendGroupSignal(info); err != syserror.ESRCH { return 0, nil, err } @@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) err := tg.SendSignal(info) if err == syserror.ESRCH { // ESRCH is ignored because it means the task @@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) // See note above regarding ESRCH race above. if err := tg.SendSignal(info); err != syserror.ESRCH { lastErr = err @@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI Signo: int32(sig), Code: arch.SignalInfoTkill, } - info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9cd052c3d..4adfa6637 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -749,11 +749,6 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i return 0, err } - // FIXME(b/63594852): Pretend we have an empty error queue. - if flags&linux.MSG_ERRQUEUE != 0 { - return 0, syserror.EAGAIN - } - // Fast path when no control message nor name buffers are provided. if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 983f8d396..8e7ac0ffe 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal si := arch.SignalInfo{ Signo: int32(linux.SIGCHLD), } - si.SetPid(int32(wr.TID)) - si.SetUid(int32(wr.UID)) + si.SetPID(int32(wr.TID)) + si.SetUID(int32(wr.UID)) // TODO(b/73541790): convert kernel.ExitStatus to functions and make // WaitResult.Status a linux.WaitStatus. s := syscall.WaitStatus(wr.Status) diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 6d0a38330..1365a5a62 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := aioCtx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := aioCtx.Prepare(); err != nil { + return err } if eventFD != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index ee38fdca0..6986e39fe 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -42,7 +42,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return syserror.EINVAL } - r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) + r, w, err := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) + if err != nil { + return err + } defer r.DecRef(t) defer w.DecRef(t) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 7b33b3f59..987012acc 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -752,11 +752,6 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla return 0, err } - // FIXME(b/63594852): Pretend we have an empty error queue. - if flags&linux.MSG_ERRQUEUE != 0 { - return 0, syserror.EAGAIN - } - // Fast path when no control message nor name buffers are provided. if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index cb48c37a1..0df023713 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.12 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - package vfs import ( @@ -41,6 +36,15 @@ type mountKey struct { point unsafe.Pointer // *Dentry } +var ( + mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil)) + mountKeySeed = sync.RandUintptr() +) + +func (k *mountKey) hash() uintptr { + return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed) +} + func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry { } } -func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } - // Invariant: mnt.key.parent == nil. vd.Ok(). func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } -func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } - // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). // // mountTable.Init() must be called on new mountTables before use. -// -// +stateify savable type mountTable struct { // mountTable is implemented as a seqcount-protected hash table that // resolves collisions with linear probing, featuring Robin Hood insertion @@ -84,8 +82,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq sync.SeqCount `state:"nosave"` - seed uint32 // for hashing keys + seq sync.SeqCount `state:"nosave"` // size holds both length (number of elements) and capacity (number of // slots): capacity is stored as its base-2 log (referred to as order) in @@ -150,7 +147,6 @@ func init() { // Init must be called exactly once on each mountTable before use. func (mt *mountTable) Init() { - mt.seed = rand32() mt.size = mtInitOrder mt.slots = newMountTableSlots(mtInitCap) } @@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := key.hash() loop: for { @@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must not already contain a Mount with the same mount point and parent. func (mt *mountTable) insertSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() // We're under the maximum load factor if: // @@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must contain mount. func (mt *mountTable) removeSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 slots := mt.slots @@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) { off = (off + mountSlotBytes) & offmask } } - -//go:linkname memhash runtime.memhash -func memhash(p unsafe.Pointer, seed, s uintptr) uintptr - -//go:linkname rand32 runtime.fastrand -func rand32() uint32 diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go index 8f070ed53..8998a82dd 100644 --- a/pkg/sentry/vfs/save_restore.go +++ b/pkg/sentry/vfs/save_restore.go @@ -101,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount { return mounts } +// saveKey is called by stateify. +func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } + // loadMounts is called by stateify. func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { if mounts == nil { @@ -112,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { } } +// loadKey is called by stateify. +func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } + func (mnt *Mount) afterLoad() { if atomic.LoadInt64(&mnt.refs) != 0 { refsvfs2.Register(mnt) diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go index d462c3eef..e8315326d 100644 --- a/pkg/shim/v1/proc/process.go +++ b/pkg/shim/v1/proc/process.go @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package proc contains process-related utilities. package proc import ( diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD index 05c595bc9..e5b6bf186 100644 --- a/pkg/shim/v1/shim/BUILD +++ b/pkg/shim/v1/shim/BUILD @@ -8,6 +8,7 @@ go_library( "api.go", "platform.go", "service.go", + "shim.go", ], visibility = [ "//pkg/shim:__subpackages__", diff --git a/pkg/shim/v1/shim/shim.go b/pkg/shim/v1/shim/shim.go new file mode 100644 index 000000000..1855a8769 --- /dev/null +++ b/pkg/shim/v1/shim/shim.go @@ -0,0 +1,17 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package shim contains the core containerd shim implementation. +package shim diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go index 07e346654..21e75d16d 100644 --- a/pkg/shim/v1/utils/utils.go +++ b/pkg/shim/v1/utils/utils.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package utils contains utility functions. package utils import ( diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD index f37fefddc..b0e8daa51 100644 --- a/pkg/shim/v2/BUILD +++ b/pkg/shim/v2/BUILD @@ -22,6 +22,7 @@ go_library( "//runsc/specutils", "@com_github_burntsushi_toml//:go_default_library", "@com_github_containerd_cgroups//:go_default_library", + "@com_github_containerd_cgroups//stats/v1:go_default_library", "@com_github_containerd_console//:go_default_library", "@com_github_containerd_containerd//api/events:go_default_library", "@com_github_containerd_containerd//api/types/task:go_default_library", diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go index cba403cae..6aaf5fab8 100644 --- a/pkg/shim/v2/service.go +++ b/pkg/shim/v2/service.go @@ -28,6 +28,7 @@ import ( "github.com/BurntSushi/toml" "github.com/containerd/cgroups" + cgroupsstats "github.com/containerd/cgroups/stats/v1" "github.com/containerd/console" "github.com/containerd/containerd/api/events" "github.com/containerd/containerd/api/types/task" @@ -735,48 +736,48 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI. // as runc. // // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81 - metrics := &cgroups.Metrics{ - CPU: &cgroups.CPUStat{ - Usage: &cgroups.CPUUsage{ + metrics := &cgroupsstats.Metrics{ + CPU: &cgroupsstats.CPUStat{ + Usage: &cgroupsstats.CPUUsage{ Total: stats.Cpu.Usage.Total, Kernel: stats.Cpu.Usage.Kernel, User: stats.Cpu.Usage.User, PerCPU: stats.Cpu.Usage.Percpu, }, - Throttling: &cgroups.Throttle{ + Throttling: &cgroupsstats.Throttle{ Periods: stats.Cpu.Throttling.Periods, ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods, ThrottledTime: stats.Cpu.Throttling.ThrottledTime, }, }, - Memory: &cgroups.MemoryStat{ + Memory: &cgroupsstats.MemoryStat{ Cache: stats.Memory.Cache, - Usage: &cgroups.MemoryEntry{ + Usage: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Usage.Limit, Usage: stats.Memory.Usage.Usage, Max: stats.Memory.Usage.Max, Failcnt: stats.Memory.Usage.Failcnt, }, - Swap: &cgroups.MemoryEntry{ + Swap: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Swap.Limit, Usage: stats.Memory.Swap.Usage, Max: stats.Memory.Swap.Max, Failcnt: stats.Memory.Swap.Failcnt, }, - Kernel: &cgroups.MemoryEntry{ + Kernel: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Kernel.Limit, Usage: stats.Memory.Kernel.Usage, Max: stats.Memory.Kernel.Max, Failcnt: stats.Memory.Kernel.Failcnt, }, - KernelTCP: &cgroups.MemoryEntry{ + KernelTCP: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.KernelTCP.Limit, Usage: stats.Memory.KernelTCP.Usage, Max: stats.Memory.KernelTCP.Max, Failcnt: stats.Memory.KernelTCP.Failcnt, }, }, - Pids: &cgroups.PidsStat{ + Pids: &cgroupsstats.PidsStat{ Current: stats.Pids.Current, Limit: stats.Pids.Limit, }, diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go index d3931c952..2b1609af0 100644 --- a/pkg/state/tests/integer_test.go +++ b/pkg/state/tests/integer_test.go @@ -20,21 +20,21 @@ import ( ) var ( - allIntTs = []int{-1, 0, 1} - allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} - allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} - allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} - allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} - allUintTs = []uint{0, 1} - allUintptrs = []uintptr{0, 1, ^uintptr(0)} - allUint8s = []uint8{0, 1, math.MaxUint8} - allUint16s = []uint16{0, 1, math.MaxUint16} - allUint32s = []uint32{0, 1, math.MaxUint32} - allUint64s = []uint64{0, 1, math.MaxUint64} + allBasicInts = []int{-1, 0, 1} + allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} + allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} + allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} + allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} + allBasicUints = []uint{0, 1} + allUintptrs = []uintptr{0, 1, ^uintptr(0)} + allUint8s = []uint8{0, 1, math.MaxUint8} + allUint16s = []uint16{0, 1, math.MaxUint16} + allUint32s = []uint32{0, 1, math.MaxUint32} + allUint64s = []uint64{0, 1, math.MaxUint64} ) var allInts = flatten( - allIntTs, + allBasicInts, allInt8s, allInt16s, allInt32s, @@ -42,7 +42,7 @@ var allInts = flatten( ) var allUints = flatten( - allUintTs, + allBasicUints, allUintptrs, allUint8s, allUint16s, diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index be5bc99fc..28e62abbb 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -10,15 +10,34 @@ exports_files(["LICENSE"]) go_template( name = "generic_atomicptr", - srcs = ["atomicptr_unsafe.go"], + srcs = ["generic_atomicptr_unsafe.go"], types = [ "Value", ], ) go_template( + name = "generic_atomicptrmap", + srcs = ["generic_atomicptrmap_unsafe.go"], + opt_consts = [ + "ShardOrder", + ], + opt_types = [ + "Hasher", + ], + types = [ + "Key", + "Value", + ], + deps = [ + ":sync", + "//pkg/gohacks", + ], +) + +go_template( name = "generic_seqatomic", - srcs = ["seqatomic_unsafe.go"], + srcs = ["generic_seqatomic_unsafe.go"], types = [ "Value", ], diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmaptest/BUILD new file mode 100644 index 000000000..3f71ae97d --- /dev/null +++ b/pkg/sync/atomicptrmaptest/BUILD @@ -0,0 +1,57 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], +) + +go_template_instance( + name = "test_atomicptrmap", + out = "test_atomicptrmap_unsafe.go", + package = "atomicptrmap", + prefix = "test", + template = "//pkg/sync:generic_atomicptrmap", + types = { + "Key": "int64", + "Value": "testValue", + }, +) + +go_template_instance( + name = "test_atomicptrmap_sharded", + out = "test_atomicptrmap_sharded_unsafe.go", + consts = { + "ShardOrder": "4", + }, + package = "atomicptrmap", + prefix = "test", + suffix = "Sharded", + template = "//pkg/sync:generic_atomicptrmap", + types = { + "Key": "int64", + "Value": "testValue", + }, +) + +go_library( + name = "atomicptrmap", + testonly = 1, + srcs = [ + "atomicptrmap.go", + "test_atomicptrmap_sharded_unsafe.go", + "test_atomicptrmap_unsafe.go", + ], + deps = [ + "//pkg/gohacks", + "//pkg/sync", + ], +) + +go_test( + name = "atomicptrmap_test", + size = "small", + srcs = ["atomicptrmap_test.go"], + library = ":atomicptrmap", + deps = ["//pkg/sync"], +) diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap.go b/pkg/sync/atomicptrmaptest/atomicptrmap.go new file mode 100644 index 000000000..867821ce9 --- /dev/null +++ b/pkg/sync/atomicptrmaptest/atomicptrmap.go @@ -0,0 +1,20 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package atomicptrmap instantiates generic_atomicptrmap for testing. +package atomicptrmap + +type testValue struct { + val int +} diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go new file mode 100644 index 000000000..75a9997ef --- /dev/null +++ b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go @@ -0,0 +1,635 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomicptrmap + +import ( + "context" + "fmt" + "math/rand" + "reflect" + "runtime" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +func TestConsistencyWithGoMap(t *testing.T) { + const maxKey = 16 + var vals [4]*testValue + for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { + vals[i] = new(testValue) + } + var ( + m = make(map[int64]*testValue) + apm testAtomicPtrMap + ) + for i := 0; i < 100000; i++ { + // Apply a random operation to both m and apm and expect them to have + // the same result. Bias toward CompareAndSwap, which has the most + // cases; bias away from Range and RangeRepeatable, which are + // relatively expensive. + switch rand.Intn(10) { + case 0, 1: // Load + key := rand.Int63n(maxKey) + want := m[key] + got := apm.Load(key) + t.Logf("Load(%d) = %p", key, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 2, 3: // Swap + key := rand.Int63n(maxKey) + val := vals[rand.Intn(len(vals))] + want := m[key] + if val != nil { + m[key] = val + } else { + delete(m, key) + } + got := apm.Swap(key, val) + t.Logf("Swap(%d, %p) = %p", key, val, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 4, 5, 6, 7: // CompareAndSwap + key := rand.Int63n(maxKey) + oldVal := vals[rand.Intn(len(vals))] + newVal := vals[rand.Intn(len(vals))] + want := m[key] + if want == oldVal { + if newVal != nil { + m[key] = newVal + } else { + delete(m, key) + } + } + got := apm.CompareAndSwap(key, oldVal, newVal) + t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 8: // Range + got := make(map[int64]*testValue) + var ( + haveDup = false + dup int64 + ) + apm.Range(func(key int64, val *testValue) bool { + if _, ok := got[key]; ok && !haveDup { + haveDup = true + dup = key + } + got[key] = val + return true + }) + t.Logf("Range() = %v", got) + if !reflect.DeepEqual(got, m) { + t.Fatalf("got %v, wanted %v", got, m) + } + if haveDup { + t.Fatalf("got duplicate key %d", dup) + } + case 9: // RangeRepeatable + got := make(map[int64]*testValue) + apm.RangeRepeatable(func(key int64, val *testValue) bool { + got[key] = val + return true + }) + t.Logf("RangeRepeatable() = %v", got) + if !reflect.DeepEqual(got, m) { + t.Fatalf("got %v, wanted %v", got, m) + } + } + } +} + +func TestConcurrentHeterogeneous(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var ( + apm testAtomicPtrMap + wg sync.WaitGroup + ) + defer func() { + cancel() + wg.Wait() + }() + + possibleKeyValuePairs := make(map[int64]map[*testValue]struct{}) + addKeyValuePair := func(key int64, val *testValue) { + values := possibleKeyValuePairs[key] + if values == nil { + values = make(map[*testValue]struct{}) + possibleKeyValuePairs[key] = values + } + values[val] = struct{}{} + } + + const numValuesPerKey = 4 + + // These goroutines use keys not used by any other goroutine. + const numPrivateKeys = 3 + for i := 0; i < numPrivateKeys; i++ { + key := int64(i) + var vals [numValuesPerKey]*testValue + for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { + val := new(testValue) + vals[i] = val + addKeyValuePair(key, val) + } + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + var stored *testValue + for ctx.Err() == nil { + switch r.Intn(4) { + case 0: + got := apm.Load(key) + if got != stored { + t.Errorf("Load(%d): got %p, wanted %p", key, got, stored) + return + } + case 1: + val := vals[r.Intn(len(vals))] + want := stored + stored = val + got := apm.Swap(key, val) + if got != want { + t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want) + return + } + case 2, 3: + oldVal := vals[r.Intn(len(vals))] + newVal := vals[r.Intn(len(vals))] + want := stored + if stored == oldVal { + stored = newVal + } + got := apm.CompareAndSwap(key, oldVal, newVal) + if got != want { + t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want) + return + } + } + } + }() + } + + // These goroutines share a small set of keys. + const numSharedKeys = 2 + var ( + sharedKeys [numSharedKeys]int64 + sharedValues = make(map[int64][]*testValue) + sharedValuesSet = make(map[int64]map[*testValue]struct{}) + ) + for i := range sharedKeys { + key := int64(numPrivateKeys + i) + sharedKeys[i] = key + vals := make([]*testValue, numValuesPerKey) + valsSet := make(map[*testValue]struct{}) + for j := range vals { + val := new(testValue) + vals[j] = val + valsSet[val] = struct{}{} + addKeyValuePair(key, val) + } + sharedValues[key] = vals + sharedValuesSet[key] = valsSet + } + randSharedValue := func(r *rand.Rand, key int64) *testValue { + vals := sharedValues[key] + return vals[r.Intn(len(vals))] + } + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + for ctx.Err() == nil { + keyIndex := r.Intn(len(sharedKeys)) + key := sharedKeys[keyIndex] + var ( + op string + got *testValue + ) + switch r.Intn(4) { + case 0: + op = "Load" + got = apm.Load(key) + case 1: + op = "Swap" + got = apm.Swap(key, randSharedValue(r, key)) + case 2, 3: + op = "CompareAndSwap" + got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key)) + } + if got != nil { + valsSet := sharedValuesSet[key] + if _, ok := valsSet[got]; !ok { + t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet) + return + } + } + } + }() + } + + // This goroutine repeatedly searches for unused keys. + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + for ctx.Err() == nil { + key := -1 - r.Int63() + if got := apm.Load(key); got != nil { + t.Errorf("Load(%d): got %p, wanted nil", key, got) + } + } + }() + + // This goroutine repeatedly calls RangeRepeatable() and checks that each + // key corresponds to an expected value. + wg.Add(1) + go func() { + defer wg.Done() + abort := false + for !abort && ctx.Err() == nil { + apm.RangeRepeatable(func(key int64, val *testValue) bool { + values, ok := possibleKeyValuePairs[key] + if !ok { + t.Errorf("RangeRepeatable: got invalid key %d", key) + abort = true + return false + } + if _, ok := values[val]; !ok { + t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values) + abort = true + return false + } + return true + }) + } + }() + + // Finally, the main goroutine spins for the length of the test calling + // Range() and checking that each key that it observes is unique and + // corresponds to an expected value. + seenKeys := make(map[int64]struct{}) + const testDuration = 5 * time.Second + end := time.Now().Add(testDuration) + abort := false + for time.Now().Before(end) { + apm.Range(func(key int64, val *testValue) bool { + values, ok := possibleKeyValuePairs[key] + if !ok { + t.Errorf("Range: got invalid key %d", key) + abort = true + return false + } + if _, ok := values[val]; !ok { + t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values) + abort = true + return false + } + if _, ok := seenKeys[key]; ok { + t.Errorf("Range: got duplicate key %d", key) + abort = true + return false + } + seenKeys[key] = struct{}{} + return true + }) + if abort { + break + } + for k := range seenKeys { + delete(seenKeys, k) + } + } +} + +type benchmarkableMap interface { + Load(key int64) *testValue + Store(key int64, val *testValue) + LoadOrStore(key int64, val *testValue) (*testValue, bool) + Delete(key int64) +} + +// rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map. +type rwMutexMap struct { + mu sync.RWMutex + m map[int64]*testValue +} + +func (m *rwMutexMap) Load(key int64) *testValue { + m.mu.RLock() + defer m.mu.RUnlock() + return m.m[key] +} + +func (m *rwMutexMap) Store(key int64, val *testValue) { + m.mu.Lock() + defer m.mu.Unlock() + if m.m == nil { + m.m = make(map[int64]*testValue) + } + m.m[key] = val +} + +func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + m.mu.Lock() + defer m.mu.Unlock() + if m.m == nil { + m.m = make(map[int64]*testValue) + } + if oldVal, ok := m.m[key]; ok { + return oldVal, true + } + m.m[key] = val + return val, false +} + +func (m *rwMutexMap) Delete(key int64) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.m, key) +} + +// syncMap implements benchmarkableMap for a sync.Map. +type syncMap struct { + m sync.Map +} + +func (m *syncMap) Load(key int64) *testValue { + val, ok := m.m.Load(key) + if !ok { + return nil + } + return val.(*testValue) +} + +func (m *syncMap) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + actual, loaded := m.m.LoadOrStore(key, val) + return actual.(*testValue), loaded +} + +func (m *syncMap) Delete(key int64) { + m.m.Delete(key) +} + +// benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap. +type benchmarkableAtomicPtrMap struct { + m testAtomicPtrMap +} + +func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue { + return m.m.Load(key) +} + +func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { + return prev, true + } + return val, false +} + +func (m *benchmarkableAtomicPtrMap) Delete(key int64) { + m.m.Store(key, nil) +} + +// benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded. +type benchmarkableAtomicPtrMapSharded struct { + m testAtomicPtrMapSharded +} + +func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue { + return m.m.Load(key) +} + +func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { + return prev, true + } + return val, false +} + +func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) { + m.m.Store(key, nil) +} + +var mapImpls = [...]struct { + name string + ctor func() benchmarkableMap +}{ + { + name: "RWMutexMap", + ctor: func() benchmarkableMap { + return new(rwMutexMap) + }, + }, + { + name: "SyncMap", + ctor: func() benchmarkableMap { + return new(syncMap) + }, + }, + { + name: "AtomicPtrMap", + ctor: func() benchmarkableMap { + return new(benchmarkableAtomicPtrMap) + }, + }, + { + name: "AtomicPtrMapSharded", + ctor: func() benchmarkableMap { + return new(benchmarkableAtomicPtrMapSharded) + }, + }, +} + +func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + for i := 0; i < b.N; i++ { + m.Delete(int64(i)) + } +} + +func BenchmarkStoreDelete(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkStoreDelete(b, mapImpl.ctor) + }) + } +} + +func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.LoadOrStore(int64(i), val) + } + for i := 0; i < b.N; i++ { + m.Delete(int64(i)) + } +} + +func BenchmarkLoadOrStoreDelete(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLoadOrStoreDelete(b, mapImpl.ctor) + }) + } +} + +func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Load(int64(i)) + } +} + +func BenchmarkLookupPositive(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLookupPositive(b, mapImpl.ctor) + }) + } +} + +func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Load(int64(-1 - i)) + } +} + +func BenchmarkLookupNegative(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLookupNegative(b, mapImpl.ctor) + }) + } +} + +type benchmarkConcurrentOptions struct { + // loadsPerMutationPair is the number of map lookups between each + // insertion/deletion pair. + loadsPerMutationPair int + + // If changeKeys is true, the keys used by each goroutine change between + // iterations of the test. + changeKeys bool +} + +func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) { + var ( + started sync.WaitGroup + workers sync.WaitGroup + ) + started.Add(1) + + m := mapCtor() + val := &testValue{} + // Insert a large number of unused elements into the map so that used + // elements are distributed throughout memory. + for i := 0; i < 10000; i++ { + m.Store(int64(-1-i), val) + } + // n := ceil(b.N / (opts.loadsPerMutationPair + 2)) + n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2) + for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ { + workerID := i + workers.Add(1) + go func() { + defer workers.Done() + started.Wait() + for i := 0; i < n; i++ { + var key int64 + if opts.changeKeys { + key = int64(workerID*n + i) + } else { + key = int64(workerID) + } + m.LoadOrStore(key, val) + for j := 0; j < opts.loadsPerMutationPair; j++ { + m.Load(key) + } + m.Delete(key) + } + }() + } + + b.ResetTimer() + started.Done() + workers.Wait() +} + +func BenchmarkConcurrent(b *testing.B) { + changeKeysChoices := [...]struct { + name string + val bool + }{ + {"FixedKeys", false}, + {"ChangingKeys", true}, + } + writePcts := [...]struct { + name string + loadsPerMutationPair int + }{ + {"1PercentWrites", 198}, + {"10PercentWrites", 18}, + {"50PercentWrites", 2}, + } + for _, changeKeys := range changeKeysChoices { + for _, writePct := range writePcts { + for _, mapImpl := range mapImpls { + name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name) + b.Run(name, func(b *testing.B) { + benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{ + loadsPerMutationPair: writePct.loadsPerMutationPair, + changeKeys: changeKeys.val, + }) + }) + } + } + } +} diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/generic_atomicptr_unsafe.go index 525c4beed..82b6df18c 100644 --- a/pkg/sync/atomicptr_unsafe.go +++ b/pkg/sync/generic_atomicptr_unsafe.go @@ -3,9 +3,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package template doesn't exist. This file must be instantiated using the +// Package seqatomic doesn't exist. This file must be instantiated using the // go_template_instance rule in tools/go_generics/defs.bzl. -package template +package seqatomic import ( "sync/atomic" diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/generic_atomicptrmap_unsafe.go new file mode 100644 index 000000000..c70dda6dd --- /dev/null +++ b/pkg/sync/generic_atomicptrmap_unsafe.go @@ -0,0 +1,503 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package atomicptrmap doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package atomicptrmap + +import ( + "reflect" + "runtime" + "sync/atomic" + "unsafe" + + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sync" +) + +// Key is a required type parameter. +type Key struct{} + +// Value is a required type parameter. +type Value struct{} + +const ( + // ShardOrder is an optional parameter specifying the base-2 log of the + // number of shards per AtomicPtrMap. Higher values of ShardOrder reduce + // unnecessary synchronization between unrelated concurrent operations, + // improving performance for write-heavy workloads, but increase memory + // usage for small maps. + ShardOrder = 0 +) + +// Hasher is an optional type parameter. If Hasher is provided, it must define +// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps. +type Hasher struct { + defaultHasher +} + +// defaultHasher is the default Hasher. This indirection exists because +// defaultHasher must exist even if a custom Hasher is provided, to prevent the +// Go compiler from complaining about defaultHasher's unused imports. +type defaultHasher struct { + fn func(unsafe.Pointer, uintptr) uintptr + seed uintptr +} + +// Init initializes the Hasher. +func (h *defaultHasher) Init() { + h.fn = sync.MapKeyHasher(map[Key]*Value(nil)) + h.seed = sync.RandUintptr() +} + +// Hash returns the hash value for the given Key. +func (h *defaultHasher) Hash(key Key) uintptr { + return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed) +} + +var hasher Hasher + +func init() { + hasher.Init() +} + +// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMap are +// safe for concurrent use from multiple goroutines without additional +// synchronization. +// +// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for +// use. AtomicPtrMaps must not be copied after first use. +// +// sync.Map may be faster than AtomicPtrMap if most operations on the map are +// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in +// other circumstances. +type AtomicPtrMap struct { + // AtomicPtrMap is implemented as a hash table with the following + // properties: + // + // * Collisions are resolved with quadratic probing. Of the two major + // alternatives, Robin Hood linear probing makes it difficult for writers + // to execute in parallel, and bucketing is less effective in Go due to + // lack of SIMD. + // + // * The table is optionally divided into shards indexed by hash to further + // reduce unnecessary synchronization. + + shards [1 << ShardOrder]apmShard +} + +func (m *AtomicPtrMap) shard(hash uintptr) *apmShard { + // Go defines right shifts >= width of shifted unsigned operand as 0, so + // this is correct even if ShardOrder is 0 (although nogo complains because + // nogo is dumb). + const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder + index := hash >> indexLSB + return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{})))) +} + +type apmShard struct { + apmShardMutationData + _ [apmShardMutationDataPadding]byte + apmShardLookupData + _ [apmShardLookupDataPadding]byte +} + +type apmShardMutationData struct { + dirtyMu sync.Mutex // serializes slot transitions out of empty + dirty uintptr // # slots with val != nil + count uintptr // # slots with val != nil and val != tombstone() + rehashMu sync.Mutex // serializes rehashing +} + +type apmShardLookupData struct { + seq sync.SeqCount // allows atomic reads of slots+mask + slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq + mask uintptr // always (a power of 2) - 1; protected by rehashMu/seq +} + +const ( + cacheLineBytes = 64 + // Cache line padding is enabled if sharding is. + apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise + // The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) % + // cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes). + apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1) + apmShardMutationDataPadding = apmEnablePadding * apmShardMutationDataRequiredPadding + apmShardLookupDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1) + apmShardLookupDataPadding = apmEnablePadding * apmShardLookupDataRequiredPadding + + // These define fractional thresholds for when apmShard.rehash() is called + // (i.e. the load factor) and when it rehases to a larger table + // respectively. They are chosen such that the rehash threshold = the + // expansion threshold + 1/2, so that when reuse of deleted slots is rare + // or non-existent, rehashing occurs after the insertion of at least 1/2 + // the table's size in new entries, which is acceptably infrequent. + apmRehashThresholdNum = 2 + apmRehashThresholdDen = 3 + apmExpansionThresholdNum = 1 + apmExpansionThresholdDen = 6 +) + +type apmSlot struct { + // slot states are indicated by val: + // + // * Empty: val == nil; key is meaningless. May transition to full or + // evacuated with dirtyMu locked. + // + // * Full: val != nil, tombstone(), or evacuated(); key is immutable. val + // is the Value mapped to key. May transition to deleted or evacuated. + // + // * Deleted: val == tombstone(); key is still immutable. key is mapped to + // no Value. May transition to full or evacuated. + // + // * Evacuated: val == evacuated(); key is immutable. Set by rehashing on + // slots that have already been moved, requiring readers to wait for + // rehashing to complete and use the new table. Terminal state. + // + // Note that once val is non-nil, it cannot become nil again. That is, the + // transition from empty to non-empty is irreversible for a given slot; + // the only way to create more empty slots is by rehashing. + val unsafe.Pointer + key Key +} + +func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot { + return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{}))) +} + +var tombstoneObj byte + +func tombstone() unsafe.Pointer { + return unsafe.Pointer(&tombstoneObj) +} + +var evacuatedObj byte + +func evacuated() unsafe.Pointer { + return unsafe.Pointer(&evacuatedObj) +} + +// Load returns the Value stored in m for key. +func (m *AtomicPtrMap) Load(key Key) *Value { + hash := hasher.Hash(key) + shard := m.shard(hash) + +retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + return nil + } + + i := hash & mask + inc := uintptr(1) + for { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil { + // Empty slot; end of probe sequence. + return nil + } + if slotVal == evacuated() { + // Racing with rehashing. + goto retry + } + if slot.key == key { + if slotVal == tombstone() { + return nil + } + return (*Value)(slotVal) + } + i = (i + inc) & mask + inc++ + } +} + +// Store stores the Value val for key. +func (m *AtomicPtrMap) Store(key Key, val *Value) { + m.maybeCompareAndSwap(key, false, nil, val) +} + +// Swap stores the Value val for key and returns the previously-mapped Value. +func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value { + return m.maybeCompareAndSwap(key, false, nil, val) +} + +// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it +// stores the Value newVal for key. CompareAndSwap returns the previous Value +// stored for key, whether or not it stores newVal. +func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value { + return m.maybeCompareAndSwap(key, true, oldVal, newVal) +} + +func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value { + hash := hasher.Hash(key) + shard := m.shard(hash) + oldVal := tombstone() + if typedOldVal != nil { + oldVal = unsafe.Pointer(typedOldVal) + } + newVal := tombstone() + if typedNewVal != nil { + newVal = unsafe.Pointer(typedNewVal) + } + +retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + if (compare && oldVal != tombstone()) || newVal == tombstone() { + return nil + } + // Need to allocate a table before insertion. + shard.rehash(nil) + goto retry + } + + i := hash & mask + inc := uintptr(1) + for { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil { + if (compare && oldVal != tombstone()) || newVal == tombstone() { + return nil + } + // Try to grab this slot for ourselves. + shard.dirtyMu.Lock() + slotVal = atomic.LoadPointer(&slot.val) + if slotVal == nil { + // Check if we need to rehash before dirtying a slot. + if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum { + shard.dirtyMu.Unlock() + shard.rehash(slots) + goto retry + } + slot.key = key + atomic.StorePointer(&slot.val, newVal) // transitions slot to full + shard.dirty++ + atomic.AddUintptr(&shard.count, 1) + shard.dirtyMu.Unlock() + return nil + } + // Raced with another store; the slot is no longer empty. Continue + // with the new value of slotVal since we may have raced with + // another store of key. + shard.dirtyMu.Unlock() + } + if slotVal == evacuated() { + // Racing with rehashing. + goto retry + } + if slot.key == key { + // We're reusing an existing slot, so rehashing isn't necessary. + for { + if (compare && oldVal != slotVal) || newVal == slotVal { + if slotVal == tombstone() { + return nil + } + return (*Value)(slotVal) + } + if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) { + if slotVal == tombstone() { + atomic.AddUintptr(&shard.count, 1) + return nil + } + if newVal == tombstone() { + atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */) + } + return (*Value)(slotVal) + } + slotVal = atomic.LoadPointer(&slot.val) + if slotVal == evacuated() { + goto retry + } + } + } + // This produces a triangular number sequence of offsets from the + // initially-probed position. + i = (i + inc) & mask + inc++ + } +} + +// rehash is marked nosplit to avoid preemption during table copying. +//go:nosplit +func (shard *apmShard) rehash(oldSlots unsafe.Pointer) { + shard.rehashMu.Lock() + defer shard.rehashMu.Unlock() + + if shard.slots != oldSlots { + // Raced with another call to rehash(). + return + } + + // Determine the size of the new table. Constraints: + // + // * The size of the table must be a power of two to ensure that every slot + // is visitable by every probe sequence under quadratic probing with + // triangular numbers. + // + // * The size of the table cannot decrease because even if shard.count is + // currently smaller than shard.dirty, concurrent stores that reuse + // existing slots can drive shard.count back up to a maximum of + // shard.dirty. + newSize := uintptr(8) // arbitrary initial size + if oldSlots != nil { + oldSize := shard.mask + 1 + newSize = oldSize + if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum { + newSize *= 2 + } + } + + // Allocate the new table. + newSlotsSlice := make([]apmSlot, newSize) + newSlotsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&newSlotsSlice)) + newSlots := unsafe.Pointer(newSlotsReflect.Data) + runtime.KeepAlive(newSlotsSlice) + newMask := newSize - 1 + + // Start a writer critical section now so that racing users of the old + // table that observe evacuated() wait for the new table. (But lock dirtyMu + // first since doing so may block, which we don't want to do during the + // writer critical section.) + shard.dirtyMu.Lock() + shard.seq.BeginWrite() + + if oldSlots != nil { + realCount := uintptr(0) + // Copy old entries to the new table. + oldMask := shard.mask + for i := uintptr(0); i <= oldMask; i++ { + oldSlot := apmSlotAt(oldSlots, i) + val := atomic.SwapPointer(&oldSlot.val, evacuated()) + if val == nil || val == tombstone() { + continue + } + hash := hasher.Hash(oldSlot.key) + j := hash & newMask + inc := uintptr(1) + for { + newSlot := apmSlotAt(newSlots, j) + if newSlot.val == nil { + newSlot.val = val + newSlot.key = oldSlot.key + break + } + j = (j + inc) & newMask + inc++ + } + realCount++ + } + // Update dirty to reflect that tombstones were not copied to the new + // table. Use realCount since a concurrent mutator may not have updated + // shard.count yet. + shard.dirty = realCount + } + + // Switch to the new table. + atomic.StorePointer(&shard.slots, newSlots) + atomic.StoreUintptr(&shard.mask, newMask) + + shard.seq.EndWrite() + shard.dirtyMu.Unlock() +} + +// Range invokes f on each Key-Value pair stored in m. If any call to f returns +// false, Range stops iteration and returns. +// +// Range does not necessarily correspond to any consistent snapshot of the +// Map's contents: no Key will be visited more than once, but if the Value for +// any Key is stored or deleted concurrently, Range may reflect any mapping for +// that Key from any point during the Range call. +// +// f must not call other methods on m. +func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) { + for si := 0; si < len(m.shards); si++ { + shard := &m.shards[si] + if !shard.doRange(f) { + return + } + } +} + +func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool { + // We have to lock rehashMu because if we handled races with rehashing by + // retrying, f could see the same key twice. + shard.rehashMu.Lock() + defer shard.rehashMu.Unlock() + slots := shard.slots + if slots == nil { + return true + } + mask := shard.mask + for i := uintptr(0); i <= mask; i++ { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil || slotVal == tombstone() { + continue + } + if !f(slot.key, (*Value)(slotVal)) { + return false + } + } + return true +} + +// RangeRepeatable is like Range, but: +// +// * RangeRepeatable may visit the same Key multiple times in the presence of +// concurrent mutators, possibly passing different Values to f in different +// calls. +// +// * It is safe for f to call other methods on m. +func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) { + for si := 0; si < len(m.shards); si++ { + shard := &m.shards[si] + + retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + continue + } + + for i := uintptr(0); i <= mask; i++ { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == evacuated() { + goto retry + } + if slotVal == nil || slotVal == tombstone() { + continue + } + if !f(slot.key, (*Value)(slotVal)) { + return + } + } + } +} diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go index 780f3b8f8..82b676abf 100644 --- a/pkg/sync/seqatomic_unsafe.go +++ b/pkg/sync/generic_seqatomic_unsafe.go @@ -3,9 +3,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package template doesn't exist. This file must be instantiated using the +// Package seqatomic doesn't exist. This file must be instantiated using the // go_template_instance rule in tools/go_generics/defs.bzl. -package template +package seqatomic import ( "unsafe" diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go index 7ad6a4434..e925e2e5b 100644 --- a/pkg/sync/runtime_unsafe.go +++ b/pkg/sync/runtime_unsafe.go @@ -11,6 +11,8 @@ package sync import ( + "fmt" + "reflect" "unsafe" ) @@ -61,6 +63,57 @@ const ( TraceEvGoBlockSelect byte = 24 ) +// Rand32 returns a non-cryptographically-secure random uint32. +func Rand32() uint32 { + return fastrand() +} + +// Rand64 returns a non-cryptographically-secure random uint64. +func Rand64() uint64 { + return uint64(fastrand())<<32 | uint64(fastrand()) +} + +//go:linkname fastrand runtime.fastrand +func fastrand() uint32 + +// RandUintptr returns a non-cryptographically-secure random uintptr. +func RandUintptr() uintptr { + if unsafe.Sizeof(uintptr(0)) == 4 { + return uintptr(Rand32()) + } + return uintptr(Rand64()) +} + +// MapKeyHasher returns a hash function for pointers of m's key type. +// +// Preconditions: m must be a map. +func MapKeyHasher(m interface{}) func(unsafe.Pointer, uintptr) uintptr { + if rtyp := reflect.TypeOf(m); rtyp.Kind() != reflect.Map { + panic(fmt.Sprintf("sync.MapKeyHasher: m is %v, not map", rtyp)) + } + mtyp := *(**maptype)(unsafe.Pointer(&m)) + return mtyp.hasher +} + +type maptype struct { + size uintptr + ptrdata uintptr + hash uint32 + tflag uint8 + align uint8 + fieldAlign uint8 + kind uint8 + equal func(unsafe.Pointer, unsafe.Pointer) bool + gcdata *byte + str int32 + ptrToThis int32 + key unsafe.Pointer + elem unsafe.Pointer + bucket unsafe.Pointer + hasher func(unsafe.Pointer, uintptr) uintptr + // more fields +} + // These functions are only used within the sync package. //go:linkname semacquire sync.runtime_Semacquire diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go index ce667e825..5ca96d12b 100644 --- a/pkg/sync/rwmutex_test.go +++ b/pkg/sync/rwmutex_test.go @@ -102,7 +102,7 @@ func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone c } for i := 0; i < 100; i++ { } - n = atomic.AddInt32(activity, -1) + atomic.AddInt32(activity, -1) rwm.RUnlock() } cdone <- true diff --git a/pkg/syserr/host_linux.go b/pkg/syserr/host_linux.go index fc6ef60a1..77faa3670 100644 --- a/pkg/syserr/host_linux.go +++ b/pkg/syserr/host_linux.go @@ -32,7 +32,7 @@ var linuxHostTranslations [maxErrno]linuxHostTranslation // FromHost translates a syscall.Errno to a corresponding Error value. func FromHost(err syscall.Errno) *Error { - if err < 0 || int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok { + if int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok { panic(fmt.Sprintf("unknown host errno %q (%d)", err.Error(), err)) } return linuxHostTranslations[err].err diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 5ae10939d..77c3c110c 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -15,6 +15,8 @@ package syserr import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -48,45 +50,56 @@ var ( ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM) ) -var netstackErrorTranslations = map[*tcpip.Error]*Error{ - tcpip.ErrUnknownProtocol: ErrUnknownProtocol, - tcpip.ErrUnknownNICID: ErrUnknownNICID, - tcpip.ErrUnknownDevice: ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID: ErrDuplicateNICID, - tcpip.ErrDuplicateAddress: ErrDuplicateAddress, - tcpip.ErrNoRoute: ErrNoRoute, - tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound: ErrAlreadyBound, - tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected: ErrAlreadyConnected, - tcpip.ErrNoPortAvailable: ErrNoPortAvailable, - tcpip.ErrPortInUse: ErrPortInUse, - tcpip.ErrBadLocalAddress: ErrBadLocalAddress, - tcpip.ErrClosedForSend: ErrClosedForSend, - tcpip.ErrClosedForReceive: ErrClosedForReceive, - tcpip.ErrWouldBlock: ErrWouldBlock, - tcpip.ErrConnectionRefused: ErrConnectionRefused, - tcpip.ErrTimeout: ErrTimeout, - tcpip.ErrAborted: ErrAborted, - tcpip.ErrConnectStarted: ErrConnectStarted, - tcpip.ErrDestinationRequired: ErrDestinationRequired, - tcpip.ErrNotSupported: ErrNotSupported, - tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported, - tcpip.ErrNotConnected: ErrNotConnected, - tcpip.ErrConnectionReset: ErrConnectionReset, - tcpip.ErrConnectionAborted: ErrConnectionAborted, - tcpip.ErrNoSuchFile: ErrNoSuchFile, - tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress: ErrHostDown, - tcpip.ErrBadAddress: ErrBadAddress, - tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, - tcpip.ErrMessageTooLong: ErrMessageTooLong, - tcpip.ErrNoBufferSpace: ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled, - tcpip.ErrNotPermitted: ErrNotPermittedNet, - tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported, +var netstackErrorTranslations map[string]*Error + +func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) { + key := tcpipErr.String() + if _, ok := netstackErrorTranslations[key]; ok { + panic(fmt.Sprintf("duplicate error key: %s", key)) + } + netstackErrorTranslations[key] = netstackErr +} + +func init() { + netstackErrorTranslations = make(map[string]*Error) + addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol) + addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID) + addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice) + addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption) + addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID) + addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress) + addErrMapping(tcpip.ErrNoRoute, ErrNoRoute) + addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint) + addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound) + addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState) + addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting) + addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected) + addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable) + addErrMapping(tcpip.ErrPortInUse, ErrPortInUse) + addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress) + addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend) + addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive) + addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock) + addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused) + addErrMapping(tcpip.ErrTimeout, ErrTimeout) + addErrMapping(tcpip.ErrAborted, ErrAborted) + addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted) + addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired) + addErrMapping(tcpip.ErrNotSupported, ErrNotSupported) + addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported) + addErrMapping(tcpip.ErrNotConnected, ErrNotConnected) + addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset) + addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted) + addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile) + addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue) + addErrMapping(tcpip.ErrNoLinkAddress, ErrHostDown) + addErrMapping(tcpip.ErrBadAddress, ErrBadAddress) + addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable) + addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong) + addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace) + addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled) + addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet) + addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported) } // TranslateNetstackError converts an error from the tcpip package to a sentry @@ -95,7 +108,7 @@ func TranslateNetstackError(err *tcpip.Error) *Error { if err == nil { return nil } - se, ok := netstackErrorTranslations[err] + se, ok := netstackErrorTranslations[err.String()] if !ok { panic("Unknown error: " + err.String()) } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 27f96a3ac..89b765f1b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,10 +1,24 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "sock_err_list", + out = "sock_err_list.go", + package = "tcpip", + prefix = "sockError", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*SockError", + "Linker": "*SockError", + }, +) + go_library( name = "tcpip", srcs = [ + "sock_err_list.go", "socketops.go", "tcpip.go", "time_unsafe.go", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index d3ae56ac6..91971b687 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -117,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker { v = ip.TTL() case header.IPv6: v = ip.HopLimit() + case *ipv6HeaderWithExtHdr: + v = ip.HopLimit() + default: + t.Fatalf("unrecognized header type %T for TTL evaluation", ip) } if v != ttl { t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) @@ -321,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress +// field in ControlMessages. +func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasOriginalDstAddress { + t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) + } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { + t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -1400,3 +1417,189 @@ func IGMPGroupAddress(want tcpip.Address) TransportChecker { } } } + +// IPv6ExtHdrChecker is a function to check an extension header. +type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) + +// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. +func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + + ipv6 := header.IPv6(b) + if !ipv6.IsValid(len(b)) { + t.Error("not a valid IPv6 packet") + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var rawPayloadHeader header.IPv6RawPayloadHeader + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) + return + } + r, ok := h.(header.IPv6RawPayloadHeader) + if ok { + rawPayloadHeader = r + break + } + } + + networkHeader := ipv6HeaderWithExtHdr{ + IPv6: ipv6, + transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), + payload: rawPayloadHeader.Buf.ToView(), + } + + for _, checker := range checkers { + checker(t, []header.Network{&networkHeader}) + } +} + +// IPv6ExtHdr checks for the presence of extension headers. +// +// All the extension headers in headers will be checked exhaustively in the +// order provided. +func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) + if !ok { + t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), + buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), + ) + + for _, check := range headers { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) + return + } + check(t, h) + } + // Validate we consumed all headers. + // + // The next one over should be a raw payload and then iterator should + // terminate. + wantDone := false + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done != wantDone { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) + return + } + if done { + break + } + if _, ok := h.(header.IPv6RawPayloadHeader); !ok { + t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) + continue + } + wantDone = true + } + } +} + +var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) + +// ipv6HeaderWithExtHdr provides a header.Network implementation that takes +// extension headers into consideration, which is not the case with vanilla +// header.IPv6. +type ipv6HeaderWithExtHdr struct { + header.IPv6 + transport tcpip.TransportProtocolNumber + payload []byte +} + +// TransportProtocol implements header.Network. +func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { + return h.transport +} + +// Payload implements header.Network. +func (h *ipv6HeaderWithExtHdr) Payload() []byte { + return h.payload +} + +// IPv6ExtHdrOptionChecker is a function to check an extension header option. +type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) + +// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop +// extension header and validates the containing options with checkers. +// +// checkers must exhaustively contain all the expected options. +func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { + return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { + t.Helper() + + hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) + if !ok { + t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) + return + } + optionsIterator := hbh.Iter() + for _, f := range checkers { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + f(t, opt) + } + // Validate all options were consumed. + for { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if !done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + if done { + break + } + } + } +} + +// IPv6RouterAlert validates that an extension header option is the RouterAlert +// option and matches on its value. +func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + routerAlert, ok := opt.(*header.IPv6RouterAlertOption) + if !ok { + t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) + return + } + if routerAlert.Value != want { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) + } + } +} diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 309403482..5ab20ee86 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -19,6 +19,7 @@ package header_test import ( "fmt" "math/rand" + "sync" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -169,3 +170,96 @@ func BenchmarkChecksum(b *testing.B) { } } } + +func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { + // icmpChecksum should not do any modifications of the header to + // calculate its checksum. Let's call it from a few go-routines and the + // race detector will trigger a warning if there are any concurrent + // read/write accesses. + + const concurrency = 5 + start := make(chan int) + ready := make(chan bool, concurrency) + var wg sync.WaitGroup + wg.Add(concurrency) + defer wg.Wait() + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + ready <- true + <-start + + if got := headerChecksum(); want != got { + t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) + } + if got := icmpChecksum(); want != got { + t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) + } + }() + } + for i := 0; i < concurrency; i++ { + <-ready + } + close(start) +} + +func TestICMPv4Checksum(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + + h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) + if _, err := rnd.Read(h); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + h.SetChecksum(0) + + buf := make([]byte, 13) + if _, err := rnd.Read(buf); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + vv := buffer.NewVectorisedView(len(buf), []buffer.View{ + buffer.NewViewFromBytes(buf[:5]), + buffer.NewViewFromBytes(buf[5:]), + }) + + want := header.Checksum(vv.ToView(), 0) + want = ^header.Checksum(h, want) + h.SetChecksum(want) + + testICMPChecksum(t, h.Checksum, func() uint16 { + return header.ICMPv4Checksum(h, vv) + }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) +} + +func TestICMPv6Checksum(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + + h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) + if _, err := rnd.Read(h); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + h.SetChecksum(0) + + buf := make([]byte, 13) + if _, err := rnd.Read(buf); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + vv := buffer.NewVectorisedView(len(buf), []buffer.View{ + buffer.NewViewFromBytes(buf[:7]), + buffer.NewViewFromBytes(buf[7:10]), + buffer.NewViewFromBytes(buf[10:]), + }) + + dst := header.IPv6Loopback + src := header.IPv6Loopback + + want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) + want = header.Checksum(vv.ToView(), want) + want = ^header.Checksum(h, want) + h.SetChecksum(want) + + testICMPChecksum(t, h.Checksum, func() uint16 { + return header.ICMPv6Checksum(h, src, dst, vv) + }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 2f13dea6a..5f9b8e9e2 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -199,17 +200,24 @@ func (b ICMPv4) SetSequence(sequence uint16) { // ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, // and payload. func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := uint16(0) - for _, v := range vv.Views() { - xsum = Checksum(v, xsum) - } + xsum := ChecksumVV(vv, 0) + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = Checksum(h[:2], xsum) + xsum = Checksum(h[4:], xsum) - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^Checksum(h, xsum) - h[2], h[3] = h2, h3 + return ^xsum +} - return xsum +// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when +// a packet having a `net` header causing an ICMP error. +func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin { + switch net { + case IPv4ProtocolNumber: + return tcpip.SockExtErrorOriginICMP + case IPv6ProtocolNumber: + return tcpip.SockExtErrorOriginICMP6 + default: + panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net)) + } } diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 2eef64b4d..eca9750ab 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -265,22 +265,13 @@ func (b ICMPv6) Payload() []byte { // ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := Checksum([]byte(src), 0) - xsum = Checksum([]byte(dst), xsum) - var upperLayerLength [4]byte - binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) - xsum = Checksum(upperLayerLength[:], xsum) - xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum) - for _, v := range vv.Views() { - xsum = Checksum(v, xsum) - } - - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^Checksum(h, xsum) - h[2], h[3] = h2, h3 - - return xsum + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) + + xsum = ChecksumVV(vv, xsum) + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = Checksum(h[:2], xsum) + xsum = Checksum(h[4:], xsum) + + return ^xsum } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 55d09355a..5580d6a78 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -18,7 +18,6 @@ import ( "crypto/sha256" "encoding/binary" "fmt" - "strings" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -48,11 +47,13 @@ type IPv6Fields struct { // FlowLabel is the "flow label" field of an IPv6 packet. FlowLabel uint32 - // PayloadLength is the "payload length" field of an IPv6 packet. + // PayloadLength is the "payload length" field of an IPv6 packet, including + // the length of all extension headers. PayloadLength uint16 - // NextHeader is the "next header" field of an IPv6 packet. - NextHeader uint8 + // TransportProtocol is the transport layer protocol number. Serialized in the + // last "next header" field of the IPv6 header + extension headers. + TransportProtocol tcpip.TransportProtocolNumber // HopLimit is the "Hop Limit" field of an IPv6 packet. HopLimit uint8 @@ -62,6 +63,9 @@ type IPv6Fields struct { // DstAddr is the "destination ip address" of an IPv6 packet. DstAddr tcpip.Address + + // ExtensionHeaders are the extension headers following the IPv6 header. + ExtensionHeaders IPv6ExtHdrSerializer } // IPv6 represents an ipv6 header stored in a byte array. @@ -148,13 +152,17 @@ const ( // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the // catch-all or wildcard subnet. That is, all IPv6 addresses are considered to // be contained within this subnet. -var IPv6EmptySubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) - if err != nil { - panic(err) - } - return subnet -}() +var IPv6EmptySubnet = tcpip.AddressWithPrefix{ + Address: IPv6Any, + PrefixLen: 0, +}.Subnet() + +// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined +// by RFC 4291 section 2.5.5. +var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{ + Address: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00", + PrefixLen: 96, +}.Subnet() // IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined // by RFC 4291 section 2.5.6. @@ -253,12 +261,14 @@ func (IPv6) SetChecksum(uint16) { // Encode encodes all the fields of the ipv6 header. func (b IPv6) Encode(i *IPv6Fields) { + extHdr := b[IPv6MinimumSize:] b.SetTOS(i.TrafficClass, i.FlowLabel) b.SetPayloadLength(i.PayloadLength) - b[IPv6NextHeaderOffset] = i.NextHeader b[hopLimit] = i.HopLimit b.SetSourceAddress(i.SrcAddr) b.SetDestinationAddress(i.DstAddr) + nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr) + b[IPv6NextHeaderOffset] = nextHeader } // IsValid performs basic validation on the packet. @@ -286,7 +296,7 @@ func IsV4MappedAddress(addr tcpip.Address) bool { return false } - return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff") + return IPv4MappedIPv6Subnet.Contains(addr) } // IsV6MulticastAddress determines if the provided address is an IPv6 @@ -392,17 +402,6 @@ func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope } -// IsV6UniqueLocalAddress determines if the provided address is an IPv6 -// unique-local address (within the prefix FC00::/7). -func IsV6UniqueLocalAddress(addr tcpip.Address) bool { - if len(addr) != IPv6AddressSize { - return false - } - // According to RFC 4193 section 3.1, a unique local address has the prefix - // FC00::/7. - return (addr[0] & 0xfe) == 0xfc -} - // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier // (IID) to buf as outlined by RFC 7217 and returns the extended buffer. // @@ -449,9 +448,6 @@ const ( // LinkLocalScope indicates a link-local address. LinkLocalScope IPv6AddressScope = iota - // UniqueLocalScope indicates a unique-local address. - UniqueLocalScope - // GlobalScope indicates a global address. GlobalScope ) @@ -469,9 +465,6 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { case IsV6LinkLocalAddress(addr): return LinkLocalScope, nil - case IsV6UniqueLocalAddress(addr): - return UniqueLocalScope, nil - default: return GlobalScope, nil } diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 571eae233..f18981332 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -18,9 +18,12 @@ import ( "bufio" "bytes" "encoding/binary" + "errors" "fmt" "io" + "math" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -75,8 +78,8 @@ const ( // Fragment Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetOffset = 0 - // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to - // discard from the Fragment Offset. + // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment + // Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetShift = 3 // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an @@ -114,6 +117,37 @@ const ( IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8 ) +// padIPv6OptionsLength returns the total length for IPv6 options of length l +// considering the 8-octet alignment as stated in RFC 8200 Section 4.2. +func padIPv6OptionsLength(length int) int { + return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1) +} + +// padIPv6Option fills b with the appropriate padding options depending on its +// length. +func padIPv6Option(b []byte) { + switch len(b) { + case 0: // No padding needed. + case 1: // Pad with Pad1. + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier) + default: // Pad with PadN. + s := b[ipv6ExtHdrOptionPayloadOffset:] + for i := range s { + s[i] = 0 + } + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier) + b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s)) + } +} + +// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to +// serialize an option at headerOffset with alignment requirements +// [align]n + alignOffset. +func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int { + padLen := headerOffset - alignOffset + return ((padLen + align - 1) & ^(align - 1)) - padLen +} + // IPv6PayloadHeader is implemented by the various headers that can be found // in an IPv6 payload. // @@ -206,29 +240,55 @@ type IPv6ExtHdrOption interface { isIPv6ExtHdrOption() } -// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier. -type IPv6ExtHdrOptionIndentifier uint8 +// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier. +type IPv6ExtHdrOptionIdentifier uint8 const ( // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that // provides 1 byte padding, as outlined in RFC 8200 section 4.2. - ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0 + ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0 // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that // provides variable length byte padding, as outlined in RFC 8200 section 4.2. - ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1 + ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1 + + // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router + // Alert Hop by Hop option as defined in RFC 2711 section 2.1. + ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5 + + // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header + // option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionTypeOffset = 0 + + // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionLengthOffset = 1 + + // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionPayloadOffset = 2 ) +// ipv6UnknownActionFromIdentifier maps an extension header option's +// identifier's high bits to the action to take when the identifier is unknown. +func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction { + return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) +} + +// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option +// is malformed. +var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option") + // IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension // header option that is unknown by the parsing utilities. type IPv6UnknownExtHdrOption struct { - Identifier IPv6ExtHdrOptionIndentifier + Identifier IPv6ExtHdrOptionIdentifier Data []byte } // UnknownAction implements IPv6OptionUnknownAction.UnknownAction. func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction { - return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) + return ipv6UnknownActionFromIdentifier(o.Identifier) } // isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption. @@ -251,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // options buffer has been exhausted and we are done iterating. return nil, true, nil } - id := IPv6ExtHdrOptionIndentifier(temp) + id := IPv6ExtHdrOptionIdentifier(temp) // If the option identifier indicates the option is a Pad1 option, then we // know the option does not have Length and Data fields. End processing of @@ -294,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } continue + case ipv6RouterAlertHopByHopOptionIdentifier: + var routerAlertValue [ipv6RouterAlertPayloadLength]byte + if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil { + switch err { + case io.EOF, io.ErrUnexpectedEOF: + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + default: + return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err) + } + } else if n != int(length) { + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + } + return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil default: bytes := make([]byte, length) if n, err := io.ReadFull(&i.reader, bytes); err != nil { @@ -609,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil } + +// IPv6SerializableExtHdr provides serialization for IPv6 extension +// headers. +type IPv6SerializableExtHdr interface { + // identifier returns the assigned IPv6 header identifier for this extension + // header. + identifier() IPv6ExtensionHeaderIdentifier + + // length returns the total serialized length in bytes of this extension + // header, including the common next header and length fields. + length() int + + // serializeInto serializes the receiver into the provided byte + // buffer and with the provided nextHeader value. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto returns the number of bytes that was used to serialize the + // receiver. Implementers must only use the number of bytes required to + // serialize the receiver. Callers MAY provide a larger buffer than required + // to serialize into. + serializeInto(nextHeader uint8, b []byte) int +} + +var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil) + +// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop +// options extension header. +type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption + +const ( + // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field + // in a hop by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrNextHeaderOffset = 0 + + // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop + // by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrLengthOffset = 1 + + // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by + // hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrOptionsOffset = 2 + + // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet + // words in a hop by hop extension header's length field, as stated in RFC + // 8200 section 4.3: + // Length of the Hop-by-Hop Options header in 8-octet units, + // not including the first 8 octets. + ipv6HopByHopExtHdrUnaccountedLenWords = 1 +) + +// identifier implements IPv6SerializableExtHdr. +func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6HopByHopOptionsExtHdrIdentifier +} + +// length implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) length() int { + var total int + for _, opt := range h { + align, alignOffset := opt.alignment() + total += ipv6OptionsAlignmentPadding(total, align, alignOffset) + total += ipv6ExtHdrOptionPayloadOffset + int(opt.length()) + } + // Account for next header and total length fields and add padding. + return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total) +} + +// serializeInto implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int { + optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:] + totalLength := ipv6HopByHopExtHdrOptionsOffset + for _, opt := range h { + // Calculate alignment requirements and pad buffer if necessary. + align, alignOffset := opt.alignment() + padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset) + if padLen != 0 { + padIPv6Option(optBuffer[:padLen]) + totalLength += padLen + optBuffer = optBuffer[padLen:] + } + + l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:]) + optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier()) + optBuffer[ipv6ExtHdrOptionLengthOffset] = l + l += ipv6ExtHdrOptionPayloadOffset + totalLength += int(l) + optBuffer = optBuffer[l:] + } + padded := padIPv6OptionsLength(totalLength) + if padded != totalLength { + padIPv6Option(optBuffer[:padded-totalLength]) + totalLength = padded + } + wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords + if wordsLen > math.MaxUint8 { + panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen)) + } + b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader + b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen) + return totalLength +} + +// IPv6SerializableHopByHopOption provides serialization for hop by hop options. +type IPv6SerializableHopByHopOption interface { + // identifier returns the option identifier of this Hop by Hop option. + identifier() IPv6ExtHdrOptionIdentifier + + // length returns the *payload* size of the option (not considering the type + // and length fields). + length() uint8 + + // alignment returns the alignment requirements from this option. + // + // Alignment requirements take the form [align]n + offset as specified in + // RFC 8200 section 4.2. The alignment requirement is on the offset between + // the option type byte and the start of the hop by hop header. + // + // align must be a power of 2. + alignment() (align int, offset int) + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) uint8 +} + +var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil) + +// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in +// RFC 2711 section 2.1. +type IPv6RouterAlertOption struct { + Value IPv6RouterAlertValue +} + +// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option. +type IPv6RouterAlertValue uint16 + +const ( + // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener + // Discovery message as defined in RFC 2711 section 2.1. + IPv6RouterAlertMLD IPv6RouterAlertValue = 0 + // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as + // defined in RFC 2711 section 2.1. + IPv6RouterAlertRSVP IPv6RouterAlertValue = 1 + // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active + // Networks message as defined in RFC 2711 section 2.1. + IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2 + + // ipv6RouterAlertPayloadLength is the length of the Router Alert payload + // as defined in RFC 2711. + ipv6RouterAlertPayloadLength = 2 + + // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the + // Router Alert option defined as 2n+0 in RFC 2711. + ipv6RouterAlertAlignmentRequirement = 2 + + // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset + // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section + // 2.1. + ipv6RouterAlertAlignmentOffsetRequirement = 0 +) + +// UnknownAction implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction { + return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier) +} + +// isIPv6ExtHdrOption implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {} + +// identifier implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier { + return ipv6RouterAlertHopByHopOptionIdentifier +} + +// length implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) length() uint8 { + return ipv6RouterAlertPayloadLength +} + +// alignment implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) alignment() (int, int) { + // From RFC 2711 section 2.1: + // Alignment requirement: 2n+0. + return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 { + binary.BigEndian.PutUint16(b, uint16(o.Value)) + return ipv6RouterAlertPayloadLength +} + +// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers. +type IPv6ExtHdrSerializer []IPv6SerializableExtHdr + +// Serialize serializes the provided list of IPv6 extension headers into b. +// +// Note, b must be of sufficient size to hold all the headers in s. See +// IPv6ExtHdrSerializer.Length for details on the getting the total size of a +// serialized IPv6ExtHdrSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +// +// Serialize takes the transportProtocol value to be used as the last extension +// header's Next Header value and returns the header identifier of the first +// serialized extension header and the total serialized length. +func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) { + nextHeader := uint8(transportProtocol) + if len(s) == 0 { + return nextHeader, 0 + } + var totalLength int + for i, h := range s[:len(s)-1] { + length := h.serializeInto(uint8(s[i+1].identifier()), b) + b = b[length:] + totalLength += length + } + totalLength += s[len(s)-1].serializeInto(nextHeader, b) + return uint8(s[0].identifier()), totalLength +} + +// Length returns the total number of bytes required to serialize the extension +// headers. +func (s IPv6ExtHdrSerializer) Length() int { + var totalLength int + for _, h := range s { + totalLength += h.length() + } + return totalLength +} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go index ab20c5f37..65adc6250 100644 --- a/pkg/tcpip/header/ipv6_extension_headers_test.go +++ b/pkg/tcpip/header/ipv6_extension_headers_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool func TestIPv6UnknownExtHdrOption(t *testing.T) { tests := []struct { name string - identifier IPv6ExtHdrOptionIndentifier + identifier IPv6ExtHdrOptionIdentifier expectedUnknownAction IPv6OptionUnknownAction }{ { @@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) { bytes: []byte{1, 3}, err: io.ErrUnexpectedEOF, }, + { + name: "Router alert without data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with partial data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with partial data and Pad1", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with extra data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with missing data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1}, + err: io.ErrUnexpectedEOF, + }, } check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) { @@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) { }) } } + +var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil) + +// dummyHbHOptionSerializer provides a generic implementation of +// IPv6SerializableHopByHopOption for use in tests. +type dummyHbHOptionSerializer struct { + id IPv6ExtHdrOptionIdentifier + payload []byte + align int + alignOffset int +} + +// identifier implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier { + return s.id +} + +// length implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) length() uint8 { + return uint8(len(s.payload)) +} + +// alignment implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) alignment() (int, int) { + align := 1 + if s.align != 0 { + align = s.align + } + return align, s.alignOffset +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 { + return uint8(copy(b, s.payload)) +} + +func TestIPv6HopByHopSerializer(t *testing.T) { + validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + dummy, ok := serializable.(*dummyHbHOptionSerializer) + if !ok { + t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable) + } + unknown, ok := deserialized.(*IPv6UnknownExtHdrOption) + if !ok { + t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{}) + } + if dummy.id != unknown.Identifier { + t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id) + } + if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" { + t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff) + } + } + tests := []struct { + name string + nextHeader uint8 + options []IPv6SerializableHopByHopOption + expect []byte + validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption) + }{ + { + name: "single option", + nextHeader: 13, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 15, + payload: []byte{9, 8, 7, 6}, + }, + }, + expect: []byte{13, 0, 15, 4, 9, 8, 7, 6}, + validate: validateDummies, + }, + { + name: "short option padN zero", + nextHeader: 88, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5}, + }, + }, + expect: []byte{88, 0, 22, 2, 4, 5, 1, 0}, + validate: validateDummies, + }, + { + name: "short option pad1", + nextHeader: 11, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 33, + payload: []byte{1, 2, 3}, + }, + }, + expect: []byte{11, 0, 33, 3, 1, 2, 3, 0}, + validate: validateDummies, + }, + { + name: "long option padN", + nextHeader: 55, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 77, + payload: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + }, + expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options align 2n", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 2, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0}, + validate: validateDummies, + }, + { + name: "two options align 8n+1", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 8, + alignOffset: 1, + }, + }, + expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0}, + validate: validateDummies, + }, + { + name: "no options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{}, + expect: []byte{33, 0, 1, 4, 0, 0, 0, 0}, + }, + { + name: "Router Alert", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}}, + expect: []byte{33, 0, 5, 2, 0, 0, 1, 0}, + validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + routerAlert, ok := deserialized.(*IPv6RouterAlertOption) + if !ok { + t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized) + } + if routerAlert.Value != IPv6RouterAlertMLD { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD) + } + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6SerializableHopByHopExtHdr(test.options) + length := s.length() + if length != len(test.expect) { + t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect)) + } + b := make([]byte, length) + for i := range b { + // Fill the buffer with ones to ensure all padding is correctly set. + b[i] = 0xFF + } + if got := s.serializeInto(test.nextHeader, b); got != length { + t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length) + } + if diff := cmp.Diff(test.expect, b); diff != "" { + t.Fatalf("serialization mismatch (-want +got):\n%s", diff) + } + + // Deserialize the options and verify them. + optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit + iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter() + for _, testOpt := range test.options { + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + test.validate(t, testOpt, opt) + } + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if !done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + }) + } +} + +var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil) + +// dummyIPv6ExtHdrSerializer provides a generic implementation of +// IPv6SerializableExtHdr for use in tests. +// +// The dummy header always carries the nextHeader value in the first byte. +type dummyIPv6ExtHdrSerializer struct { + id IPv6ExtensionHeaderIdentifier + headerContents []byte +} + +// identifier implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier { + return s.id +} + +// length implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) length() int { + return len(s.headerContents) + 1 +} + +// serializeInto implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int { + b[0] = nextHeader + return copy(b[1:], s.headerContents) + 1 +} + +func TestIPv6ExtHdrSerializer(t *testing.T) { + tests := []struct { + name string + headers []IPv6SerializableExtHdr + nextHeader tcpip.TransportProtocolNumber + expectSerialized []byte + expectNextHeader uint8 + }{ + { + name: "one header", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 15, + headerContents: []byte{1, 2, 3, 4}, + }, + }, + nextHeader: TCPProtocolNumber, + expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4}, + expectNextHeader: 15, + }, + { + name: "two headers", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 22, + headerContents: []byte{1, 2, 3}, + }, + &dummyIPv6ExtHdrSerializer{ + id: 23, + headerContents: []byte{4, 5, 6}, + }, + }, + nextHeader: ICMPv6ProtocolNumber, + expectSerialized: []byte{ + 23, 1, 2, 3, + byte(ICMPv6ProtocolNumber), 4, 5, 6, + }, + expectNextHeader: 22, + }, + { + name: "no headers", + headers: []IPv6SerializableExtHdr{}, + nextHeader: UDPProtocolNumber, + expectSerialized: []byte{}, + expectNextHeader: byte(UDPProtocolNumber), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6ExtHdrSerializer(test.headers) + l := s.Length() + if got, want := l, len(test.expectSerialized); got != want { + t.Fatalf("got serialized length = %d, want = %d", got, want) + } + b := make([]byte, l) + for i := range b { + // Fill the buffer with garbage to make sure we're writing to all bytes. + b[i] = 0xFF + } + nextHeader, serializedLen := s.Serialize(test.nextHeader, b) + if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader { + t.Errorf( + "got s.Serialize(..) = (%d, %d), want = (%d, %d)", + nextHeader, + serializedLen, + test.expectNextHeader, + len(test.expectSerialized), + ) + } + if diff := cmp.Diff(test.expectSerialized, b); diff != "" { + t.Errorf("serialization mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go index 018555a26..9d09f32eb 100644 --- a/pkg/tcpip/header/ipv6_fragment.go +++ b/pkg/tcpip/header/ipv6_fragment.go @@ -27,12 +27,11 @@ const ( idV6 = 4 ) -// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the -// fields of a packet that needs to be encoded. -type IPv6FragmentFields struct { - // NextHeader is the "next header" field of an IPv6 fragment. - NextHeader uint8 +var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil) +// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment +// extension header as defined in RFC 8200 section 4.5. +type IPv6SerializableFragmentExtHdr struct { // FragmentOffset is the "fragment offset" field of an IPv6 fragment. FragmentOffset uint16 @@ -43,6 +42,29 @@ type IPv6FragmentFields struct { Identification uint32 } +// identifier implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6FragmentHeader +} + +// length implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) length() int { + return IPv6FragmentHeaderSize +} + +// serializeInto implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int { + // Prevent too many bounds checks. + _ = b[IPv6FragmentHeaderSize:] + binary.BigEndian.PutUint32(b[idV6:], h.Identification) + binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift) + b[nextHdrFrag] = nextHeader + if h.M { + b[more] |= ipv6FragmentExtHdrMFlagMask + } + return IPv6FragmentHeaderSize +} + // IPv6Fragment represents an ipv6 fragment header stored in a byte array. // Most of the methods of IPv6Fragment access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. @@ -58,16 +80,6 @@ const ( IPv6FragmentHeaderSize = 8 ) -// Encode encodes all the fields of the ipv6 fragment. -func (b IPv6Fragment) Encode(i *IPv6FragmentFields) { - b[nextHdrFrag] = i.NextHeader - binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3) - if i.M { - b[more] |= 1 - } - binary.BigEndian.PutUint32(b[idV6:], i.Identification) -} - // IsValid performs basic validation on the fragment header. func (b IPv6Fragment) IsValid() bool { return len(b) >= IPv6FragmentHeaderSize diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 426a873b1..e3fbd64f3 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -215,48 +215,6 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { } } -func TestIsV6UniqueLocalAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Unique 1", - addr: uniqueLocalAddr1, - expected: true, - }, - { - name: "Valid Unique 2", - addr: uniqueLocalAddr1, - expected: true, - }, - { - name: "Link Local", - addr: linkLocalAddr, - expected: false, - }, - { - name: "Global", - addr: globalAddr, - expected: false, - }, - { - name: "IPv4", - addr: "\x01\x02\x03\x04", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - func TestIsV6LinkLocalMulticastAddress(t *testing.T) { tests := []struct { name string @@ -346,7 +304,7 @@ func TestScopeForIPv6Address(t *testing.T) { { name: "Unique Local", addr: uniqueLocalAddr1, - scope: header.UniqueLocalScope, + scope: header.GlobalScope, err: nil, }, { diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 0efbfb22b..d9f8e3b35 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -31,7 +31,7 @@ type PacketInfo struct { Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO - Route *stack.Route + Route stack.RouteInfo } // Notification is the interface for receiving notification from the packet @@ -230,15 +230,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket stores outbound packets into the channel. func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Clone r then release its resource so we only get the relevant fields from - // stack.Route without holding a reference to a NIC's endpoint. - route := r.Clone() - route.Release() p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, - Route: route, + Route: r.GetFields(), } e.q.Write(p) @@ -248,17 +244,13 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets stores outbound packets into the channel. func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // Clone r then release its resource so we only get the relevant fields from - // stack.Route without holding a reference to a NIC's endpoint. - route := r.Clone() - route.Release() n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, - Route: route, + Route: r.GetFields(), } if !e.q.Write(p) { diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 9f2084eae..cb94cbea6 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher } switch sa.(type) { case *unix.SockaddrLinklayer: - // enable PACKET_FANOUT mode is the underlying socket is - // of type AF_PACKET. - const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG + // Enable PACKET_FANOUT mode if the underlying socket is of type + // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will + // prevent gvisor from receiving fragmented packets and the host does the + // reassembly on our behalf before delivering the fragments. This makes it + // hard to test fragmentation reassembly code in Netstack. + const fanoutType = unix.PACKET_FANOUT_HASH fanoutArg := fanoutID | fanoutType<<16 if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index ce4da7230..a87abc6d6 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -323,9 +323,8 @@ func TestPreserveSrcAddress(t *testing.T) { defer c.cleanup() // Set LocalLinkAddress in route to the value of the bridged address. - r := &stack.Route{ - LocalLinkAddress: baddr, - } + var r stack.Route + r.LocalLinkAddress = baddr r.ResolveWith(raddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -335,7 +334,7 @@ func TestPreserveSrcAddress(t *testing.T) { ReserveHeaderBytes: header.EthernetMinimumSize, Data: buffer.VectorisedView{}, }) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 3e4afcdad..b511d3a31 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -51,7 +51,8 @@ func TestInjectableEndpointDispatch(t *testing.T) { Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) pkt.TransportHeader().Push(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} + var packetRoute stack.Route + packetRoute.RemoteAddress = dstIP endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) @@ -73,7 +74,8 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { Data: buffer.NewView(0).ToVectorisedView(), }) pkt.TransportHeader().Push(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} + var packetRoute stack.Route + packetRoute.RemoteAddress = dstIP endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 27667f5f0..b7458b620 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -154,8 +154,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. - newRoute := r.Clone() - pkt.EgressRoute = newRoute + pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] @@ -178,11 +177,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB for pkt := pkts.Front(); pkt != nil; { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] nxt := pkt.Next() - // Since qdisc can hold onto a packet for long we should Clone - // the route here to ensure it doesn't get released while the - // packet is still in our queue. - newRoute := pkt.EgressRoute.Clone() - pkt.EgressRoute = newRoute if !d.q.enqueue(pkt) { if enqueued > 0 { d.newPacketWaker.Assert() diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go index eb5abb906..45adcbccb 100644 --- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go +++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go @@ -61,6 +61,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { q.mu.Lock() r := q.used < q.limit if r { + s.EgressRoute.Acquire() q.list.PushBack(s) q.used++ } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 7131392cc..dd2e1a125 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -340,9 +340,8 @@ func TestPreserveSrcAddressInSend(t *testing.T) { newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) // Set both remote and local link address in route. - r := stack.Route{ - LocalLinkAddress: newLocalLinkAddress, - } + var r stack.Route + r.LocalLinkAddress = newLocalLinkAddress r.ResolveWith(remoteLinkAddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 8d9a91020..1a2cc39eb 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -263,7 +263,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe fragmentOffset = fragOffset case header.ARPProtocolNumber: - if parse.ARP(pkt) { + if !parse.ARP(pkt) { return } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index a364c5801..bfac358f4 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 { return nil, false } @@ -272,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. if info.Pkt.LinkHeader().View().IsEmpty() { - d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } vv.AppendView(info.Pkt.LinkHeader().View()) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 0fb373612..a25cba513 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -441,9 +441,8 @@ func (*testInterface) Promiscuous() bool { } func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - r := stack.Route{ - NetProto: protocol, - } + var r stack.Route + r.NetProto = protocol r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -557,8 +556,8 @@ func TestLinkAddressRequest(t *testing.T) { t.Fatal("expected to send a link address request") } - if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index d8e4a3b54..429af69ee 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -18,7 +18,6 @@ go_template_instance( go_library( name = "fragmentation", srcs = [ - "frag_heap.go", "fragmentation.go", "reassembler.go", "reassembler_list.go", @@ -38,7 +37,6 @@ go_test( name = "fragmentation_test", size = "small", srcs = [ - "frag_heap_test.go", "fragmentation_test.go", "reassembler_test.go", ], diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go deleted file mode 100644 index 0b570d25a..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "fmt" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -type fragment struct { - offset uint16 - vv buffer.VectorisedView -} - -type fragHeap []fragment - -func (h *fragHeap) Len() int { - return len(*h) -} - -func (h *fragHeap) Less(i, j int) bool { - return (*h)[i].offset < (*h)[j].offset -} - -func (h *fragHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *fragHeap) Push(x interface{}) { - *h = append(*h, x.(fragment)) -} - -func (h *fragHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[:n-1] - return x -} - -// reassamble empties the heap and returns a VectorisedView -// containing a reassambled version of the fragments inside the heap. -func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { - curr := heap.Pop(h).(fragment) - views := curr.vv.Views() - size := curr.vv.Size() - - if curr.offset != 0 { - return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) - } - - for h.Len() > 0 { - curr := heap.Pop(h).(fragment) - if int(curr.offset) < size { - curr.vv.TrimFront(size - int(curr.offset)) - } else if int(curr.offset) > size { - return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) - } - size += curr.vv.Size() - views = append(views, curr.vv.Views()...) - } - return buffer.NewVectorisedView(size, views), nil -} diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go deleted file mode 100644 index 9ececcb9f..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -var reassambleTestCases = []struct { - comment string - in []fragment - want buffer.VectorisedView -}{ - { - comment: "Non-overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Non-overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Duplicated packets", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(1, "0"), - }, - { - comment: "Overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(2, "01")}, - {offset: 1, vv: vv(2, "12")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(2, "12")}, - {offset: 0, vv: vv(2, "01")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping subset in-order", - in: []fragment{ - {offset: 0, vv: vv(3, "012")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(3, "012"), - }, - { - comment: "Overlapping subset out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(3, "012")}, - }, - want: vv(3, "012"), - }, -} - -func TestReassamble(t *testing.T) { - for _, c := range reassambleTestCases { - t.Run(c.comment, func(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, err := h.reassemble() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, c.want) { - t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) - } - }) - } -} - -func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when the first packet had offset != 0") - } -} - -func TestReassambleFailsForHoles(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) - heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when there was a hole in the packet") - } -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index d31296a41..1af87d713 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -53,6 +53,10 @@ var ( // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps // with another one. ErrFragmentOverlap = errors.New("overlapping fragments") + + // ErrFragmentConflict indicates that, during reassembly, some fragments are + // in conflict with one another. + ErrFragmentConflict = errors.New("conflicting fragments") ) // FragmentID is the identifier for a fragment. diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 04072d966..9b20bb1d8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -15,9 +15,8 @@ package fragmentation import ( - "container/heap" - "fmt" "math" + "sort" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -29,6 +28,8 @@ type hole struct { first uint16 last uint16 filled bool + final bool + data buffer.View } type reassembler struct { @@ -39,7 +40,6 @@ type reassembler struct { mu sync.Mutex holes []hole filled int - heap fragHeap done bool creationTime int64 pkt *stack.PacketBuffer @@ -48,51 +48,71 @@ type reassembler struct { func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, - holes: make([]hole, 0, 16), - heap: make(fragHeap, 0, 8), creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, last: math.MaxUint16, filled: false, + final: true, }) return r } -// updateHoles updates the list of holes for an incoming fragment. It returns -// true if the fragment fits, it is not a duplicate and it does not overlap with -// another fragment. -// -// For IPv6, overlaps with an existing fragment are explicitly forbidden by -// RFC 8200 section 4.5: -// If any of the fragments being reassembled overlap with any other fragments -// being reassembled for the same packet, reassembly of that packet must be -// abandoned and all the fragments that have been received for that packet -// must be discarded, and no ICMP error messages should be sent. -// -// It is not explicitly forbidden for IPv4, but to keep parity with Linux we -// disallow it as well: -// https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 -func (r *reassembler) updateHoles(first, last uint16, more bool) (bool, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.done { + // A concurrent goroutine might have already reassembled + // the packet and emptied the heap while this goroutine + // was waiting on the mutex. We don't have to do anything in this case. + return buffer.VectorisedView{}, 0, false, 0, nil + } + + var holeFound bool + var consumed int for i := range r.holes { currentHole := &r.holes[i] - if currentHole.filled || last < currentHole.first || currentHole.last < first { + if last < currentHole.first || currentHole.last < first { continue } - + // For IPv6, overlaps with an existing fragment are explicitly forbidden by + // RFC 8200 section 4.5: + // If any of the fragments being reassembled overlap with any other + // fragments being reassembled for the same packet, reassembly of that + // packet must be abandoned and all the fragments that have been received + // for that packet must be discarded, and no ICMP error messages should be + // sent. + // + // It is not explicitly forbidden for IPv4, but to keep parity with Linux we + // disallow it as well: + // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 if first < currentHole.first || currentHole.last < last { // Incoming fragment only partially fits in the free hole. - return false, ErrFragmentOverlap + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap + } + if !more { + if !currentHole.final || currentHole.filled && currentHole.last != last { + // We have another final fragment, which does not perfectly overlap. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + } } - r.filled++ + holeFound = true + if currentHole.filled { + // Incoming fragment is a duplicate. + continue + } + + // We are populating the current hole with the payload and creating a new + // hole for any unfilled ranges on either end. if first > currentHole.first { r.holes = append(r.holes, hole{ first: currentHole.first, last: first - 1, filled: false, + final: false, }) } if last < currentHole.last && more { @@ -100,39 +120,22 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) (bool, error) { first: last + 1, last: currentHole.last, filled: false, + final: currentHole.final, }) + currentHole.final = false } + v := pkt.Data.ToOwnedView() + consumed = v.Size() + r.size += consumed // Update the current hole to precisely match the incoming fragment. r.holes[i] = hole{ first: first, last: last, filled: true, + final: currentHole.final, + data: v, } - return true, nil - } - - // Incoming fragment is a duplicate/subset, or its offset comes after the end - // of the reassembled payload. - return false, nil -} - -func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { - r.mu.Lock() - defer r.mu.Unlock() - if r.done { - // A concurrent goroutine might have already reassembled - // the packet and emptied the heap while this goroutine - // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, 0, false, 0, nil - } - - used, err := r.updateHoles(first, last, more) - if err != nil { - return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err) - } - - var consumed int - if used { + r.filled++ // For IPv6, it is possible to have different Protocol values between // fragments of a packet (because, unlike IPv4, the Protocol is not used to // identify a fragment). In this case, only the Protocol of the first @@ -145,22 +148,30 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s r.pkt = pkt r.proto = proto } - vv := pkt.Data - // We store the incoming packet only if it filled some holes. - heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) - consumed = vv.Size() - r.size += consumed + + break + } + if !holeFound { + // Incoming fragment is beyond end. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict } // Check if all the holes have been filled and we are ready to reassemble. if r.filled < len(r.holes) { return buffer.VectorisedView{}, 0, false, consumed, nil } - res, err := r.heap.reassemble() - if err != nil { - return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err) + + sort.Slice(r.holes, func(i, j int) bool { + return r.holes[i].first < r.holes[j].first + }) + + var size int + views := make([]buffer.View, 0, len(r.holes)) + for _, hole := range r.holes { + views = append(views, hole.data) + size += hole.data.Size() } - return res, r.proto, true, consumed, nil + return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil } func (r *reassembler) checkDoneOrMark() bool { diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index cee3063b1..2ff03eeeb 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -19,105 +19,156 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -type updateHolesParams struct { +type processParams struct { first uint16 last uint16 more bool - wantUsed bool + pkt *stack.PacketBuffer + wantDone bool wantError error } -func TestUpdateHoles(t *testing.T) { +func TestReassemblerProcess(t *testing.T) { + const proto = 99 + + v := func(size int) buffer.View { + payload := buffer.NewView(size) + for i := 1; i < size; i++ { + payload[i] = uint8(i) * 3 + } + return payload + } + + pkt := func(size int) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v(size).ToVectorisedView(), + }) + } + var tests = []struct { name string - params []updateHolesParams + params []processParams want []hole }{ { name: "No fragments", params: nil, - want: []hole{{first: 0, last: math.MaxUint16, filled: false}}, + want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, }, { name: "One fragment at beginning", - params: []updateHolesParams{{first: 0, last: 1, more: true, wantUsed: true, wantError: nil}}, + params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true}, - {first: 2, last: math.MaxUint16, filled: false}, + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: math.MaxUint16, filled: false, final: true}, }, }, { name: "One fragment in the middle", - params: []updateHolesParams{{first: 1, last: 2, more: true, wantUsed: true, wantError: nil}}, + params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true}, - {first: 0, last: 0, filled: false}, - {first: 3, last: math.MaxUint16, filled: false}, + {first: 1, last: 2, filled: true, final: false, data: v(2)}, + {first: 0, last: 0, filled: false, final: false}, + {first: 3, last: math.MaxUint16, filled: false, final: true}, }, }, { name: "One fragment at the end", - params: []updateHolesParams{{first: 1, last: 2, more: false, wantUsed: true, wantError: nil}}, + params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true}, + {first: 1, last: 2, filled: true, final: true, data: v(2)}, {first: 0, last: 0, filled: false}, }, }, { name: "One fragment completing a packet", - params: []updateHolesParams{{first: 0, last: 1, more: false, wantUsed: true, wantError: nil}}, + params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true}, + {first: 0, last: 1, filled: true, final: true, data: v(2)}, }, }, { name: "Two fragments completing a packet", - params: []updateHolesParams{ - {first: 0, last: 1, more: true, wantUsed: true, wantError: nil}, - {first: 2, last: 3, more: false, wantUsed: true, wantError: nil}, + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true}, - {first: 2, last: 3, filled: true}, + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, }, }, { name: "Two fragments completing a packet with a duplicate", - params: []updateHolesParams{ - {first: 0, last: 1, more: true, wantUsed: true, wantError: nil}, - {first: 0, last: 1, more: true, wantUsed: false, wantError: nil}, - {first: 2, last: 3, more: false, wantUsed: true, wantError: nil}, + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, + }, + }, + { + name: "Two fragments completing a packet with a partial duplicate", + params: []processParams{ + {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, + {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true}, - {first: 2, last: 3, filled: true}, + {first: 0, last: 3, filled: true, final: false, data: v(4)}, + {first: 4, last: 5, filled: true, final: true, data: v(2)}, }, }, { name: "Two overlapping fragments", - params: []updateHolesParams{ - {first: 0, last: 10, more: true, wantUsed: true, wantError: nil}, - {first: 5, last: 15, more: false, wantUsed: false, wantError: ErrFragmentOverlap}, - {first: 11, last: 15, more: false, wantUsed: true, wantError: nil}, + params: []processParams{ + {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, + {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, + }, + want: []hole{ + {first: 0, last: 10, filled: true, final: false, data: v(11)}, + {first: 11, last: math.MaxUint16, filled: false, final: true}, + }, + }, + { + name: "Two final fragments with different ends", + params: []processParams{ + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 10, last: 14, filled: true, final: true, data: v(5)}, + {first: 0, last: 9, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, }, want: []hole{ - {first: 0, last: 10, filled: true}, - {first: 11, last: 15, filled: true}, + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, }, }, { - name: "Out of bounds fragment", - params: []updateHolesParams{ - {first: 0, last: 10, more: true, wantUsed: true, wantError: nil}, - {first: 11, last: 15, more: false, wantUsed: true, wantError: nil}, - {first: 16, last: 20, more: false, wantUsed: false, wantError: nil}, + name: "Two final fragments - duplicate, with different ends", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, }, want: []hole{ - {first: 0, last: 10, filled: true}, - {first: 11, last: 15, filled: true}, + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, }, }, } @@ -126,9 +177,9 @@ func TestUpdateHoles(t *testing.T) { t.Run(test.name, func(t *testing.T) { r := newReassembler(FragmentID{}, &faketime.NullClock{}) for _, param := range test.params { - used, err := r.updateHoles(param.first, param.last, param.more) - if used != param.wantUsed || err != param.wantError { - t.Errorf("got r.updateHoles(%d, %d, %t) = (%t, %v), want = (%t, %v)", param.first, param.last, param.more, used, err, param.wantUsed, param.wantError) + _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + if done != param.wantDone || err != param.wantError { + t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) } } if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD index 6ca200b48..ca1247c1e 100644 --- a/pkg/tcpip/network/ip/BUILD +++ b/pkg/tcpip/network/ip/BUILD @@ -18,6 +18,7 @@ go_test( srcs = ["generic_multicast_protocol_test.go"], deps = [ ":ip", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/faketime", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index e308550c4..f2f0e069c 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -30,6 +30,23 @@ type hostState int // The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 // (RFC 2710 section 5). Even though the states are generic across both IGMPv2 // and MLDv1, IGMPv2 terminology will be used. +// +// ______________receive query______________ +// | | +// | _____send or receive report_____ | +// | | | | +// V | V | +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | +// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ +// | ^ | ^ | ^ | ^ +// | | | | | | | | +// ---------- ------- ---------- ------------- +// initialize new send inital fail to send send or receive +// group membership report delayed report report +// +// Not shown in the diagram above, but any state may transition into the non +// member state when a group is left. const ( // nonMember is the "'Non-Member' state, when the host does not belong to the // group on the interface. This is the initial state for all memberships on @@ -41,6 +58,15 @@ const ( // but without advertising the membership to the network. nonMember hostState = iota + // pendingMember is a newly joined member that is waiting to successfully send + // the initial set of reports. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the initial report needs to be sent. + // + // MAY NOT transition to the idle member state from this state. + pendingMember + // delayingMember is the "'Delaying Member' state, when the host belongs to // the group on the interface and has a report delay timer running for that // membership." @@ -48,6 +74,16 @@ const ( // 'Delaying Listener' is the MLDv1 term used to describe this state. delayingMember + // queuedDelayingMember is a delayingMember that failed to send a report after + // its delayed report timer fired. Hosts in this state are waiting to attempt + // retransmission of the delayed report. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the delayed report needs to be sent. + // + // May transition to idle member if a report is received for a group. + queuedDelayingMember + // idleMember is the "Idle Member" state, when the host belongs to the group // on the interface and does not have a report delay timer running for that // membership. @@ -56,6 +92,17 @@ const ( idleMember ) +func (s hostState) isDelayingMember() bool { + switch s { + case nonMember, pendingMember, idleMember: + return false + case delayingMember, queuedDelayingMember: + return true + default: + panic(fmt.Sprintf("unrecognized host state = %d", s)) + } +} + // multicastGroupState holds the Generic Multicast Protocol state for a // multicast group. type multicastGroupState struct { @@ -84,17 +131,6 @@ type multicastGroupState struct { // GenericMulticastProtocolOptions holds options for the generic multicast // protocol. type GenericMulticastProtocolOptions struct { - // Enabled indicates whether the generic multicast protocol will be - // performed. - // - // When enabled, the protocol may transmit report and leave messages when - // joining and leaving multicast groups respectively, and handle incoming - // packets. - // - // When disabled, the protocol will still keep track of locally joined groups, - // it just won't transmit and handle packets, or update groups' state. - Enabled bool - // Rand is the source of random numbers. Rand *rand.Rand @@ -123,8 +159,22 @@ type GenericMulticastProtocolOptions struct { // MulticastGroupProtocol is a multicast group protocol whose core state machine // can be represented by GenericMulticastProtocolState. type MulticastGroupProtocol interface { + // Enabled indicates whether the generic multicast protocol will be + // performed. + // + // When enabled, the protocol may transmit report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // packets. + // + // When disabled, the protocol will still keep track of locally joined groups, + // it just won't transmit and handle packets, or update groups' state. + Enabled() bool + // SendReport sends a multicast report for the specified group address. - SendReport(groupAddress tcpip.Address) *tcpip.Error + // + // Returns false if the caller should queue the report to be sent later. Note, + // returning false does not mean that the receiver hit an error. + SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) // SendLeave sends a multicast leave for the specified group address. SendLeave(groupAddress tcpip.Address) *tcpip.Error @@ -138,76 +188,119 @@ type MulticastGroupProtocol interface { // IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state // machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. // +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// // GenericMulticastProtocolState.Init MUST be called before calling any of // the methods on GenericMulticastProtocolState. +// +// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the +// multicast group protocol is disabled so that leave messages may be sent. type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + opts GenericMulticastProtocolOptions - mu struct { - sync.RWMutex + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState - // memberships holds group addresses and their associated state. - memberships map[tcpip.Address]multicastGroupState - } + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex } // Init initializes the Generic Multicast Protocol state. -func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) { - g.mu.Lock() - defer g.mu.Unlock() - g.opts = opts - g.mu.memberships = make(map[tcpip.Address]multicastGroupState) +// +// Must only be called once for the lifetime of g; Init will panic if it is +// called twice. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +// +// Note: the methods on opts.Protocol will always be called while protocolMU is +// held. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + + *g = GenericMulticastProtocolState{ + opts: opts, + memberships: make(map[tcpip.Address]multicastGroupState), + protocolMU: protocolMU, + } } -// MakeAllNonMember transitions all groups to the non-member state. +// MakeAllNonMemberLocked transitions all groups to the non-member state. // // The groups will still be considered joined locally. -func (g *GenericMulticastProtocolState) MakeAllNonMember() { - if !g.opts.Enabled { +// +// MUST be called when the multicast group protocol is disabled. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { + if !g.opts.Protocol.Enabled() { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.transitionToNonMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// InitializeGroups initializes each group, as if they were newly joined but -// without affecting the groups' join count. +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. // // Must only be called after calling MakeAllNonMember as a group should not be // initialized while it is not in the non-member state. -func (g *GenericMulticastProtocolState) InitializeGroups() { - if !g.opts.Enabled { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { + if !g.opts.Protocol.Enabled() { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.initializeNewMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// JoinGroup handles joining a new group. +// SendQueuedReportsLocked attempts to send reports for groups that failed to +// send reports during their last attempt. // -// If dontInitialize is true, the group will be not be initialized and will be -// left in the non-member state - no packets will be sent for it until it is -// initialized via InitializeGroups. -func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) { - g.mu.Lock() - defer g.mu.Unlock() +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { + for groupAddress, info := range g.memberships { + switch info.state { + case nonMember, delayingMember, idleMember: + case pendingMember: + // pendingMembers failed to send their initial unsolicited report so try + // to send the report and queue the extra unsolicited reports. + g.maybeSendInitialReportLocked(groupAddress, &info) + case queuedDelayingMember: + // queuedDelayingMembers failed to send their delayed reports so try to + // send the report and transition them to the idle state. + g.maybeSendDelayedReportLocked(groupAddress, &info) + default: + panic(fmt.Sprintf("unrecognized host state = %d", info.state)) + } + g.memberships[groupAddress] = info + } +} - if info, ok := g.mu.memberships[groupAddress]; ok { +// JoinGroupLocked handles joining a new group. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) { + if info, ok := g.memberships[groupAddress]; ok { // The group has already been joined. info.joins++ - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return } @@ -217,41 +310,43 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do // The state will be updated below, if required. state: nonMember, lastToSendReport: false, - delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() { - info, ok := g.mu.memberships[groupAddress] + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + if !g.opts.Protocol.Enabled() { + panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress)) + } + + info, ok := g.memberships[groupAddress] if !ok { panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } - info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil - info.state = idleMember - g.mu.memberships[groupAddress] = info + g.maybeSendDelayedReportLocked(groupAddress, &info) + g.memberships[groupAddress] = info }), } - if !dontInitialize && g.opts.Enabled { + if g.opts.Protocol.Enabled() { g.initializeNewMemberLocked(groupAddress, &info) } - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } -// IsLocallyJoined returns true if the group is locally joined. -func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool { - g.mu.RLock() - defer g.mu.RUnlock() - _, ok := g.mu.memberships[groupAddress] +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] return ok } -// LeaveGroup handles leaving the group. +// LeaveGroupLocked handles leaving the group. // // Returns false if the group is not currently joined. -func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool { - g.mu.Lock() - defer g.mu.Unlock() - - info, ok := g.mu.memberships[groupAddress] +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] if !ok { return false } @@ -262,30 +357,30 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) b info.joins-- if info.joins != 0 { // If we still have outstanding joins, then do nothing further. - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return true } g.transitionToNonMemberLocked(groupAddress, &info) - delete(g.mu.memberships, groupAddress) + delete(g.memberships, groupAddress) return true } -// HandleQuery handles a query message with the specified maximum response time. +// HandleQueryLocked handles a query message with the specified maximum response +// time. // // If the group address is unspecified, then reports will be scheduled for all // joined groups. // // Report(s) will be scheduled to be sent after a random duration between 0 and // the maximum response time. -func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) { - if !g.opts.Enabled { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { + if !g.opts.Protocol.Enabled() { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 2.4 (for IGMPv2), // // In a Membership Query message, the group address field is set to zero @@ -299,28 +394,27 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, // when sending a Multicast-Address-Specific Query. if groupAddress.Unspecified() { // This is a general query as the group address is unspecified. - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } - } else if info, ok := g.mu.memberships[groupAddress]; ok { + } else if info, ok := g.memberships[groupAddress]; ok { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// HandleReport handles a report message. +// HandleReportLocked handles a report message. // // If the report is for a joined group, any active delayed report will be // cancelled and the host state for the group transitions to idle. -func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) { - if !g.opts.Enabled { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { + if !g.opts.Protocol.Enabled() { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), // // If the host receives another host's Report (version 1 or 2) while it has @@ -333,23 +427,23 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember { + if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } // initializeNewMemberLocked initializes a new group membership. // -// Precondition: g.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state != nonMember { - panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) + panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) } - info.state = idleMember + info.lastToSendReport = false if groupAddress == g.opts.AllNodesAddress { // As per RFC 2236 section 6 page 10 (for IGMPv2), @@ -365,9 +459,25 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t // case. The node starts in Idle Listener state for that address on // every interface, never transitions to another state, and never sends // a Report or Done for that address. + info.state = idleMember return } + info.state = pendingMember + g.maybeSendInitialReportLocked(groupAddress, info) +} + +// maybeSendInitialReportLocked attempts to start transmission of the initial +// set of reports after newly joining a group. +// +// Host must be in pending member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != pendingMember { + panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) + } + // As per RFC 2236 section 3 page 5 (for IGMPv2), // // When a host joins a multicast group, it should immediately transmit an @@ -385,13 +495,35 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t // // TODO(gvisor.dev/issue/4901): Support a configurable number of initial // unsolicited reports. - info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil - g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + } +} + +// maybeSendDelayedReportLocked attempts to send the delayed report. +// +// Host must be in pending, delaying or queued delaying member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if !info.state.isDelayingMember() { + panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state)) + } + + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + info.state = idleMember + } else { + info.state = queuedDelayingMember + } } // maybeSendLeave attempts to send a leave message. func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { - if !g.opts.Enabled || !lastToSendReport { + if !g.opts.Protocol.Enabled() || !lastToSendReport { return } @@ -465,7 +597,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres // transitionToNonMemberLocked transitions the given multicast group the the // non-member/listener state. // -// Precondition: e.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state == nonMember { return @@ -479,7 +611,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress // setDelayTimerForAddressRLocked sets timer to send a delay report. // -// Precondition: g.mu MUST be read locked. +// Precondition: g.protocolMU MUST be read locked. func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { if info.state == nonMember { return @@ -517,6 +649,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // TODO: Reset the timer if time remaining is greater than maxResponseTime. return } + info.state = delayingMember info.delayedReportJob.Cancel() info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 670be30d4..85593f211 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -20,6 +20,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/network/ip" @@ -36,42 +37,178 @@ const ( var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) -type mockMulticastGroupProtocol struct { +type mockMulticastGroupProtocolProtectedFields struct { + sync.RWMutex + + genericMulticastGroup ip.GenericMulticastProtocolState sendReportGroupAddrCount map[tcpip.Address]int sendLeaveGroupAddrCount map[tcpip.Address]int + makeQueuePackets bool + disabled bool } -func (m *mockMulticastGroupProtocol) init() { - m.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) +type mockMulticastGroupProtocol struct { + t *testing.T + + mu mockMulticastGroupProtocolProtectedFields } -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error { - m.sendReportGroupAddrCount[groupAddress]++ - return nil +func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) { + m.mu.Lock() + defer m.mu.Unlock() + m.initLocked() + opts.Protocol = m + m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) +} + +func (m *mockMulticastGroupProtocol) initLocked() { + m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) + m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) +} + +func (m *mockMulticastGroupProtocol) setEnabled(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.disabled = !v +} + +func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.makeQueuePackets = v } +func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.JoinGroupLocked(addr) +} + +func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) +} + +func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.HandleReportLocked(addr) +} + +func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) +} + +func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) +} + +func (m *mockMulticastGroupProtocol) makeAllNonMember() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.MakeAllNonMemberLocked() +} + +func (m *mockMulticastGroupProtocol) initializeGroups() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.InitializeGroupsLocked() +} + +func (m *mockMulticastGroupProtocol) sendQueuedReports() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.SendQueuedReportsLocked() +} + +// Enabled implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be read locked. +func (m *mockMulticastGroupProtocol) Enabled() bool { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") + } + + return !m.mu.disabled +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + + m.mu.sendReportGroupAddrCount[groupAddress]++ + return !m.mu.makeQueuePackets, nil +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { - m.sendLeaveGroupAddrCount[groupAddress]++ + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + + m.mu.sendLeaveGroupAddrCount[groupAddress]++ return nil } -func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - sendReportGroupAddressesMap := make(map[tcpip.Address]int) +func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendReportGroupAddrCount := make(map[tcpip.Address]int) for _, a := range sendReportGroupAddresses { - sendReportGroupAddressesMap[a] = 1 + sendReportGroupAddrCount[a] = 1 } - sendLeaveGroupAddressesMap := make(map[tcpip.Address]int) + sendLeaveGroupAddrCount := make(map[tcpip.Address]int) for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddressesMap[a] = 1 + sendLeaveGroupAddrCount[a] = 1 } - diff := cmp.Diff(mockMulticastGroupProtocol{ - sendReportGroupAddrCount: sendReportGroupAddressesMap, - sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap, - }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{})) - mgp.init() + diff := cmp.Diff( + &mockMulticastGroupProtocol{ + mu: mockMulticastGroupProtocolProtectedFields{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, + }, + m, + cmp.AllowUnexported(mockMulticastGroupProtocol{}), + cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), + // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t + cmp.FilterPath( + func(p cmp.Path) bool { + switch p.Last().String() { + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": + return true + } + return false + }, + cmp.Ignore(), + ), + ) + m.initLocked() return diff } @@ -95,36 +232,34 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(0)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr2, }) // Joining a group should send a report immediately and another after // a random interval between 0 and the maximum unsolicited report delay. - g.JoinGroup(test.addr, false /* dontInitialize */) + mgp.joinGroup(test.addr) if test.shouldSendReports { - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -151,40 +286,42 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(1)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr2, }) - g.JoinGroup(test.addr, false /* dontInitialize */) + mgp.joinGroup(test.addr) if test.shouldSendMessages { - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Leaving a group should send a leave report immediately and cancel any // delayed reports. - if !g.LeaveGroup(test.addr) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr) + { + + if !mgp.leaveGroup(test.addr) { + t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) + } } if test.shouldSendMessages { - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -226,45 +363,43 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(2)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a report for a group we have a timer scheduled for should // cancel our delayed report timer for the group. - g.HandleReport(test.reportAddr) + mgp.handleReport(test.reportAddr) if len(test.expectReportsFor) != 0 { + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -312,49 +447,47 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a query should make us schedule a new delayed report if it // is a query directed at us or a general query. - g.HandleQuery(test.queryAddr, test.maxDelay) + mgp.handleQuery(test.queryAddr, test.maxDelay) if len(test.expectReportsFor) != 0 { clock.Advance(test.maxDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -362,133 +495,139 @@ func TestHandleQuery(t *testing.T) { } func TestJoinCount(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(4)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: time.Second, }) // Set the join count to 2 for a group. - g.JoinGroup(addr1, false /* dontInitialize */) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } // Only the first join should trigger a report to be sent. - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr1, false /* dontInitialize */) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() } // Group should still be considered joined after leaving once. - if !g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) } - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } // A leave report should only be sent once the join count reaches 0. - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() } // Leaving once more should actually remove us from the group. - if !g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) } - if g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() } // Group should no longer be joined so we should not have anything to // leave. - if g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1) + if mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) } - if g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } func TestMakeAllNonMemberAndInitialize(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should send the leave reports for each but still consider them locally // joined. - g.MakeAllNonMember() - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + mgp.makeAllNonMember() + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !g.IsLocallyJoined(group) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group) + if !mgp.isLocallyJoined(group) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) } } // Should send the initial set of unsolcited reports. - g.InitializeGroups() - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.initializeGroups() + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } @@ -496,81 +635,172 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { // TestGroupStateNonMember tests that groups do not send packets when in the // non-member state, but are still considered locally joined. func TestGroupStateNonMember(t *testing.T) { - tests := []struct { - name string - enabled bool - dontInitialize bool - }{ - { - name: "Disabled", - enabled: false, - dontInitialize: false, - }, - { - name: "Keep non-member", - enabled: true, - dontInitialize: true, - }, - { - name: "disabled and Keep non-member", - enabled: false, - dontInitialize: true, - }, + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + mgp.setEnabled(false) + + // Joining groups should not send any reports. + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() - clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: test.enabled, - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - Protocol: &mgp, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) + // Receiving a query should not send any reports. + mgp.handleQuery(addr1, time.Nanosecond) + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Nanosecond) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - g.JoinGroup(addr1, test.dontInitialize) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) - } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + // Leaving groups should not send any leave messages. + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - g.JoinGroup(addr2, test.dontInitialize) - if !g.IsLocallyJoined(addr2) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2) - } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} - g.HandleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } +func TestQueuedPackets(t *testing.T) { + clock := faketime.NewManualClock() + mgp := mockMulticastGroupProtocol{t: t} + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) - if !g.LeaveGroup(addr2) { - t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2) - } - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) - } - if g.IsLocallyJoined(addr2) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2) - } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + // Joining should trigger a SendReport, but mgp should report that we did not + // send the packet. + mgp.setQueuePackets(true) + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) + // The delayed report timer should have been cancelled since we did not send + // the initial report earlier. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to successfully send the report. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send (we should be idle). + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query but mock being unable to send reports again. + mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to send reports again - we should have a packet queued to + // send. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query again, but mock being unable to send reports. + mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report should should transition us into the idle member state, + // even if we had a packet queued. We should no longer have any packets to + // send. + mgp.handleReport(addr1) + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // When we fail to send the initial set of reports, incoming reports should + // not affect a newly joined group's reports from being sent. + mgp.setQueuePackets(true) + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.handleReport(addr2) + // Attempting to send queued reports while still unable to send reports should + // not change the host state. + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Mock being able to successfully send the report. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a314dd386..3005973d7 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -344,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -619,11 +619,11 @@ func TestReceive(t *testing.T) { view := buffer.NewView(header.IPv6MinimumSize + payloadLen) ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadLen, - NextHeader: 10, - HopLimit: ipv6.DefaultTTL, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, + PayloadLength: payloadLen, + TransportProtocol: 10, + HopLimit: ipv6.DefaultTTL, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, }) // Make payload be non-zero. @@ -993,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) { // Create the outer IPv6 header. ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIPv6Addr, + PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 20, + SrcAddr: outerSrcAddr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -1007,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp.SetIdent(0xdead) icmp.SetSequence(0xbeef) - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - NextHeader: 10, - HopLimit: 20, - SrcAddr: localIPv6Addr, - DstAddr: remoteIPv6Addr, - }) - + var extHdrs header.IPv6ExtHdrSerializer // Build the fragmentation header if needed. if c.fragmentOffset != nil { - ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) - frag.Encode(&header.IPv6FragmentFields{ - NextHeader: 10, + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ FragmentOffset: *c.fragmentOffset, M: true, Identification: 0x12345678, }) } + // Create the inner IPv6 header. + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) + ip.Encode(&header.IPv6Fields{ + PayloadLength: 100, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, + ExtensionHeaders: extHdrs, + }) + // Make payload be non-zero. for i := dataOffset; i < len(view); i++ { view[i] = uint8(i) @@ -1344,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return hdr.View().ToVectorisedView() }, @@ -1387,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + // NB: we're lying about transport protocol here to verify the raw + // fragment header bytes. + TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return hdr.View().ToVectorisedView() }, @@ -1422,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return buffer.View(ip).ToVectorisedView() }, @@ -1457,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 0134fadc0..da88d65d1 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -16,7 +16,6 @@ package ipv4 import ( "fmt" - "sync" "sync/atomic" "time" @@ -58,6 +57,9 @@ type IGMPOptions struct { // When enabled, IGMP may transmit IGMP report and leave messages when // joining and leaving multicast groups respectively, and handle incoming // IGMP packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). Enabled bool } @@ -68,8 +70,9 @@ var _ ip.MulticastGroupProtocol = (*igmpState)(nil) // igmpState.init() MUST be called after creating an IGMP state. type igmpState struct { // The IPv4 endpoint this igmpState is for. - ep *endpoint - opts IGMPOptions + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 @@ -84,20 +87,23 @@ type igmpState struct { // when false. igmpV1Present uint32 - mu struct { - sync.RWMutex - - genericMulticastProtocol ip.GenericMulticastProtocolState + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job +} - // igmpV1Job is scheduled when this interface receives an IGMPv1 style - // message, upon expiration the igmpV1Present flag is cleared. - // igmpV1Job may not be nil once igmpState is initialized. - igmpV1Job *tcpip.Job - } +// Enabled implements ip.MulticastGroupProtocol. +func (igmp *igmpState) Enabled() bool { + // No need to perform IGMP on loopback interfaces since they don't have + // neighbouring nodes. + return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled() } // SendReport implements ip.MulticastGroupProtocol. -func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { igmpType := header.IGMPv2MembershipReport if igmp.v1Present() { igmpType = header.IGMPv1MembershipReport @@ -106,6 +112,8 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { } // SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // As per RFC 2236 Section 6, Page 8: "If the interface state says the // Querier is running IGMPv1, this action SHOULD be skipped. If the flag @@ -114,18 +122,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { if igmp.v1Present() { return nil } - return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + return err } // init sets up an igmpState struct, and is required to be called before using // a new igmpState. -func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { - igmp.mu.Lock() - defer igmp.mu.Unlock() +// +// Must only be called once for the lifetime of igmp. +func (igmp *igmpState) init(ep *endpoint) { igmp.ep = ep - igmp.opts = opts - igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ - Enabled: opts.Enabled, + igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: igmp, @@ -133,11 +140,14 @@ func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { AllNodesAddress: header.IPv4AllSystems, }) igmp.igmpV1Present = igmpV1PresentDefault - igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { igmp.setV1Present(false) }) } +// handleIGMP handles an IGMP packet. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { stats := igmp.ep.protocol.stack.Stats() received := stats.IGMP.PacketsReceived @@ -207,32 +217,34 @@ func (igmp *igmpState) setV1Present(v bool) { } } +// handleMembershipQuery handles a membership query. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero // then change the state to note that an IGMPv1 router is present and // schedule the query received Job. - if maxRespTime == 0 && igmp.opts.Enabled { - igmp.mu.igmpV1Job.Cancel() - igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout) + if maxRespTime == 0 && igmp.Enabled() { + igmp.igmpV1Job.Cancel() + igmp.igmpV1Job.Schedule(v1RouterPresentTimeout) igmp.setV1Present(true) maxRespTime = v1MaxRespTime } - igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime) + igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime) } +// handleMembershipReport handles a membership report. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.HandleReport(groupAddress) + igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) } -// writePacket assembles and sends an IGMP packet with the provided fields, -// incrementing the provided stat counter on success. -func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error { +// writePacket assembles and sends an IGMP packet. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) igmpData.SetType(igmpType) igmpData.SetGroupAddress(groupAddress) @@ -243,9 +255,13 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip Data: buffer.View(igmpData).ToVectorisedView(), }) - // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, - // rather we should select an appropriate local address. - localAddr := header.IPv4Any + addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */) + if addressEndpoint == nil { + return false, nil + } + localAddr := addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() + addressEndpoint = nil igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.IGMPProtocolNumber, TTL: header.IGMPTTL, @@ -254,22 +270,22 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip &header.IPv4SerializableRouterAlertOption{}, }) - sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() - return err + sentStats.Dropped.Increment() + return false, err } switch igmpType { case header.IGMPv1MembershipReport: - sent.V1MembershipReport.Increment() + sentStats.V1MembershipReport.Increment() case header.IGMPv2MembershipReport: - sent.V2MembershipReport.Increment() + sentStats.V2MembershipReport.Increment() case header.IGMPLeaveGroup: - sent.LeaveGroup.Increment() + sentStats.LeaveGroup.Increment() default: panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) } - return nil + return true, nil } // joinGroup handles adding a new group to the membership map, setting up the @@ -278,28 +294,27 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // // If the group already exists in the membership map, returns // tcpip.ErrDuplicateAddress. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) + igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress) } // isInGroup returns true if the specified group has been joined locally. +// +// Precondition: igmp.ep.mu must be read locked. func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { - igmp.mu.Lock() - defer igmp.mu.Unlock() - return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress) + return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Leave Group message // if required. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { - igmp.mu.Lock() - defer igmp.mu.Unlock() - // LeaveGroup returns false only if the group was not joined. - if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) { + if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } @@ -308,16 +323,23 @@ func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // softLeaveAll leaves all groups from the perspective of IGMP, but remains // joined locally. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) softLeaveAll() { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.MakeAllNonMember() + igmp.genericMulticastProtocol.MakeAllNonMemberLocked() } // initializeAll attemps to initialize the IGMP state for each group that has // been joined locally. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) initializeAll() { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.InitializeGroups() + igmp.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) sendQueuedReports() { + igmp.genericMulticastProtocol.SendQueuedReportsLocked() } diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 5e139377b..1ee573ac8 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -16,6 +16,7 @@ package ipv4_test import ( "testing" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -29,6 +30,7 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + addr = tcpip.Address("\x0a\x00\x00\x01") multicastAddr = tcpip.Address("\xe0\x00\x00\x03") nicID = 1 ) @@ -41,6 +43,7 @@ func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv4(t, payload, + checker.SrcAddr(addr), checker.DstAddr(remoteAddress), // TTL for an IGMP message must be 1 as per RFC 2236 section 2. checker.TTL(1), @@ -71,7 +74,6 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - return e, s, clock } @@ -104,6 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma // reports for backwards compatibility. func TestIgmpV1Present(t *testing.T) { e, s, clock := createStack(t, true) + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) @@ -154,3 +159,57 @@ func TestIgmpV1Present(t *testing.T) { } validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) } + +func TestSendQueuedIGMPReports(t *testing.T) { + e, s, clock := createStack(t, true) + + // Joining a group without an assigned address should queue IGMP packets; none + // should be sent without an assigned address. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) + } + reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport + if got := reportStat.Value(); got != 0 { + t.Errorf("got reportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } + + // The initial set of IGMP reports that were queued should be sent once an + // address is assigned. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + if got := reportStat.Value(); got != 1 { + t.Errorf("got reportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + if got := reportStat.Value(); got != 2 { + t.Errorf("got reportStat.Value() = %d, want = 2", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + + // Should have no more packets to send after the initial set of unsolicited + // reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 3076185cd..e9ff70d04 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -72,7 +72,6 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol - igmp igmpState // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -84,6 +83,7 @@ type endpoint struct { sync.RWMutex addressableEndpointState stack.AddressableEndpointState + igmp igmpState } } @@ -94,8 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.igmp.init(e, p.options.IGMP) + e.mu.igmp.init(e) + e.mu.Unlock() return e } @@ -127,7 +129,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of IGMP when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - e.igmp.initializeAll() + e.mu.igmp.initializeAll() // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the @@ -170,7 +172,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.isEnabled() { return } @@ -181,12 +183,16 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of IGMP so that routers know that // we are no longer interested in the group. - e.igmp.softLeaveAll() + e.mu.igmp.softLeaveAll() // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // DefaultTTL is the default time-to-live value for this endpoint. @@ -718,7 +724,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } if p == header.IGMPProtocolNumber { - e.igmp.handleIGMP(pkt) + e.mu.Lock() + e.mu.igmp.handleIGMP(pkt) + e.mu.Unlock() return } if opts := h.Options(); len(opts) != 0 { @@ -776,7 +784,12 @@ func (e *endpoint) Close() { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err == nil { + e.mu.igmp.sendQueuedReports() + } + return ep, err } // RemovePermanentAddress implements stack.AddressableEndpoint. @@ -811,6 +824,14 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { e.mu.RLock() defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements +// +// Precondition: igmp.ep.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) } @@ -843,7 +864,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadAddress } - e.igmp.joinGroup(addr) + e.mu.igmp.joinGroup(addr) return nil } @@ -858,14 +879,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { - return e.igmp.leaveGroup(addr) + return e.mu.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.igmp.isInGroup(addr) + return e.mu.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 9e2d2cfd6..ef62fe6fc 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -2669,8 +2669,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2712,8 +2712,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2761,8 +2761,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != arp.ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress) + if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) } rep := header.ARP(p.Pkt.NetworkHeader().View()) if got := rep.Op(); got != header.ARPRequest { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 5e75c8740..afa45aefe 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -58,7 +58,10 @@ go_test( srcs = ["mld_test.go"], deps = [ ":ipv6", + "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 510276b8e..6ee162713 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -645,26 +645,34 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: - var handler func(header.MLD) switch icmpType { case header.ICMPv6MulticastListenerQuery: received.MulticastListenerQuery.Increment() - handler = e.mld.handleMulticastListenerQuery case header.ICMPv6MulticastListenerReport: received.MulticastListenerReport.Increment() - handler = e.mld.handleMulticastListenerReport case header.ICMPv6MulticastListenerDone: received.MulticastListenerDone.Increment() default: panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { received.Invalid.Increment() return } - if handler != nil { - handler(header.MLD(payload.ToView())) + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + e.mu.Lock() + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerReport: + e.mu.Lock() + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerDone: + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } default: diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 32adb5c83..34a6a8446 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -149,9 +149,8 @@ func (*testInterface) Promiscuous() bool { } func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - r := stack.Route{ - NetProto: protocol, - } + var r stack.Route + r.NetProto = protocol r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -296,11 +295,11 @@ func TestICMPCounts(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) ep.HandlePacket(pkt) } @@ -454,11 +453,11 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) ep.HandlePacket(pkt) } @@ -600,8 +599,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. return } - if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr { - t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr) + if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr { + t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } // Pull the full payload since network header. Needed for header.IPv6 to @@ -853,11 +852,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), @@ -930,11 +929,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1048,11 +1047,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(icmpSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1108,11 +1107,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1227,11 +1226,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(size + payloadSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), @@ -1381,8 +1380,8 @@ func TestLinkAddressRequest(t *testing.T) { if !ok { t.Fatal("expected to send a link address request") } - if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } if pkt.Route.RemoteAddress != test.expectedRemoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) @@ -1445,11 +1444,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1463,8 +1462,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1487,11 +1486,11 @@ func TestPacketQueing(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1505,8 +1504,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1556,8 +1555,8 @@ func TestPacketQueing(t *testing.T) { t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) } snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1586,11 +1585,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1828,11 +1827,11 @@ func TestCallsToNeighborCache(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: test.source, - DstAddr: test.destination, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.source, + DstAddr: test.destination, }) ep.HandlePacket(pkt) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 8bf84601f..f2018d073 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "fmt" "hash/fnv" + "math" "sort" "sync/atomic" "time" @@ -60,6 +61,108 @@ const ( buckets = 2048 ) +// policyTable is the default policy table defined in RFC 6724 section 2.1. +// +// A more human-readable version: +// +// Prefix Precedence Label +// ::1/128 50 0 +// ::/0 40 1 +// ::ffff:0:0/96 35 4 +// 2002::/16 30 2 +// 2001::/32 5 5 +// fc00::/7 3 13 +// ::/96 1 3 +// fec0::/10 1 11 +// 3ffe::/16 1 12 +// +// The table is sorted by prefix length so longest-prefix match can be easily +// achieved. +// +// We willingly left out ::/96, fec0::/10 and 3ffe::/16 since those prefix +// assignments are deprecated. +// +// As per RFC 4291 section 2.5.5.1 (for ::/96), +// +// The "IPv4-Compatible IPv6 address" is now deprecated because the +// current IPv6 transition mechanisms no longer use these addresses. +// New or updated implementations are not required to support this +// address type. +// +// As per RFC 3879 section 4 (for fec0::/10), +// +// This document formally deprecates the IPv6 site-local unicast prefix +// defined in [RFC3513], i.e., 1111111011 binary or FEC0::/10. +// +// As per RFC 3701 section 1 (for 3ffe::/16), +// +// As clearly stated in [TEST-NEW], the addresses for the 6bone are +// temporary and will be reclaimed in the future. It further states +// that all users of these addresses (within the 3FFE::/16 prefix) will +// be required to renumber at some time in the future. +// +// and section 2, +// +// Thus after the pTLA allocation cutoff date January 1, 2004, it is +// REQUIRED that no new 6bone 3FFE pTLAs be allocated. +// +// MUST NOT BE MODIFIED. +var policyTable = [...]struct { + subnet tcpip.Subnet + + label uint8 +}{ + // ::1/128 + { + subnet: header.IPv6Loopback.WithPrefix().Subnet(), + label: 0, + }, + // ::ffff:0:0/96 + { + subnet: header.IPv4MappedIPv6Subnet, + label: 4, + }, + // 2001::/32 (Teredo prefix as per RFC 4380 section 2.6). + { + subnet: tcpip.AddressWithPrefix{ + Address: "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 32, + }.Subnet(), + label: 5, + }, + // 2002::/16 (6to4 prefix as per RFC 3056 section 2). + { + subnet: tcpip.AddressWithPrefix{ + Address: "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 16, + }.Subnet(), + label: 2, + }, + // fc00::/7 (Unique local addresses as per RFC 4193 section 3.1). + { + subnet: tcpip.AddressWithPrefix{ + Address: "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 7, + }.Subnet(), + label: 13, + }, + // ::/0 + { + subnet: header.IPv6EmptySubnet, + label: 1, + }, +} + +func getLabel(addr tcpip.Address) uint8 { + for _, p := range policyTable { + if p.subnet.Contains(addr) { + return p.label + } + } + + panic(fmt.Sprintf("should have a label for address = %s", addr)) +} + var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -85,9 +188,8 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState + mld mldState } - - mld mldState } // NICNameFromID is a function that returns a stable name for the specified NIC, @@ -122,6 +224,45 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// onAddressAssignedLocked handles an address being assigned. +// +// Precondition: e.mu must be exclusively locked. +func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, ... + // + // If we just completed DAD for a link-local address, then attempt to send any + // queued MLD reports. Note, we may have sent reports already for some of the + // groups before we had a valid link-local address to use as the source for + // the MLD messages, but that was only so that MLD snooping switches are aware + // of our membership to groups - routers would not have handled those reports. + // + // As per RFC 3590 section 4, + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + if header.IsV6LinkLocalAddress(addr) { + e.mu.mld.sendQueuedReports() + } +} + // InvalidateDefaultRouter implements stack.NDPEndpoint. func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() @@ -232,7 +373,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of MLD when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - e.mld.initializeAll() + e.mu.mld.initializeAll() // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives @@ -334,7 +475,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.Enabled() { return } @@ -349,7 +490,11 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of MLD so that routers know that // we are no longer interested in the group. - e.mld.softLeaveAll() + e.mu.mld.softLeaveAll() + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -389,19 +534,27 @@ func (e *endpoint) MTU() uint32 { // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { + // TODO(gvisor.dev/issues/5035): The maximum header length returned here does + // not open the possibility for the caller to know about size required for + // extension headers. return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) { + extHdrsLen := extensionHeaders.Length() + length := pkt.Size() + extensionHeaders.Length() + if length > math.MaxUint16 { + panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16)) + } + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(params.Protocol), - HopLimit: params.TTL, - TrafficClass: params.TOS, - SrcAddr: srcAddr, - DstAddr: dstAddr, + PayloadLength: uint16(length), + TransportProtocol: params.Protocol, + HopLimit: params.TTL, + TrafficClass: params.TOS, + SrcAddr: srcAddr, + DstAddr: dstAddr, + ExtensionHeaders: extensionHeaders, }) pkt.NetworkProtocolNumber = ProtocolNumber } @@ -456,7 +609,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */) // iptables filtering. All packets that reach here are locally // generated. @@ -545,7 +698,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */) networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { @@ -1177,13 +1330,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre return addressEndpoint, nil } - snmc := header.SolicitedNodeAddr(addr.Address) - if err := e.joinGroupLocked(snmc); err != nil { - // joinGroupLocked only returns an error if the group address is not a valid - // IPv6 multicast address. - panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) - } - addressEndpoint.SetKind(stack.PermanentTentative) if e.Enabled() { @@ -1192,6 +1338,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } } + snmc := header.SolicitedNodeAddr(addr.Address) + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) + } + return addressEndpoint, nil } @@ -1293,6 +1446,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) } +// getLinkLocalAddressRLocked returns a link-local address from the primary list +// of addresses, if one is available. +// +// See stack.PrimaryEndpointBehavior for more details about the primary list. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { + var linkLocalAddr tcpip.Address + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.IsAssigned(false /* allowExpired */) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + linkLocalAddr = addr + return false + } + } + return true + }) + return linkLocalAddr +} + // acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress // but with locking requirements. // @@ -1302,7 +1475,11 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // RFC 6724 section 5. type addrCandidate struct { addressEndpoint stack.AddressEndpoint + addr tcpip.Address scope header.IPv6AddressScope + + label uint8 + matchingPrefix uint8 } if len(remoteAddr) == 0 { @@ -1312,10 +1489,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // Create a candidate set of available addresses we can potentially use as a // source address. var cs []addrCandidate - e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { // If r is not valid for outgoing connections, it is not a valid endpoint. if !addressEndpoint.IsAssigned(allowExpired) { - return + return true } addr := addressEndpoint.AddressWithPrefix().Address @@ -1329,8 +1506,13 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address cs = append(cs, addrCandidate{ addressEndpoint: addressEndpoint, + addr: addr, scope: scope, + label: getLabel(addr), + matchingPrefix: remoteAddr.MatchingPrefix(addr), }) + + return true }) remoteScope, err := header.ScopeForIPv6Address(remoteAddr) @@ -1339,18 +1521,20 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) } + remoteLabel := getLabel(remoteAddr) + // Sort the addresses as per RFC 6724 section 5 rules 1-3. // - // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + // TODO(b/146021396): Implement rules 4, 5 of RFC 6724 section 5. sort.Slice(cs, func(i, j int) bool { sa := cs[i] sb := cs[j] // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + if sa.addr == remoteAddr { return true } - if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + if sb.addr == remoteAddr { return false } @@ -1367,11 +1551,29 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address return sbDep } + // Prefer matching label as per RFC 6724 section 5 rule 6. + if sa, sb := sa.label == remoteLabel, sb.label == remoteLabel; sa != sb { + if sa { + return true + } + if sb { + return false + } + } + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp { return saTemp } + // Use longest matching prefix as per RFC 6724 section 5 rule 8. + if sa.matchingPrefix > sb.matchingPrefix { + return true + } + if sb.matchingPrefix > sa.matchingPrefix { + return false + } + // sa and sb are equal, return the endpoint that is closest to the front of // the primary endpoint list. return i < j @@ -1417,7 +1619,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadAddress } - e.mld.joinGroup(addr) + e.mu.mld.joinGroup(addr) return nil } @@ -1432,14 +1634,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { - return e.mld.leaveGroup(addr) + return e.mu.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mld.isInGroup(addr) + return e.mu.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) @@ -1504,17 +1706,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp = ndpState{ - ep: e, - configs: p.options.NDPConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - e.mu.ndp.initializeTempAddrState() - e.mld.init(e, p.options.MLD) + e.mu.ndp.init(e) + e.mu.mld.init(e) + e.mu.Unlock() p.mu.Lock() defer p.mu.Unlock() @@ -1735,24 +1931,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea fragPkt.NetworkProtocolNumber = ProtocolNumber originalIPHeadersLength := len(originalIPHeaders) - fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + + s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + M: more, + Identification: id, + }} + + fragmentIPHeadersLength := originalIPHeadersLength + s.Length() fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) - fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) } - fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) - fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) - fragmentHeader.Encode(&header.IPv6FragmentFields{ - M: more, - FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), - Identification: id, - NextHeader: uint8(transportProto), - }) + nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:]) + + fragmentIPHeaders.SetNextHeader(nextHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) return fragPkt, more } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1c01f17ab..5f07d3af8 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -69,11 +69,11 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -127,11 +127,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -915,10 +915,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), - NextHeader: ipv6NextHdr, - HopLimit: 255, - SrcAddr: addr1, - DstAddr: dstAddr, + // We're lying about transport protocol here to be able to generate + // raw extension headers from the test definitions. + TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr), + HopLimit: 255, + SrcAddr: addr1, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1947,10 +1949,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, + // We're lying about transport protocol here so that we can generate + // raw extension headers for the tests. + TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr), + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, }) vv := hdr.View().ToVectorisedView() @@ -1995,7 +1999,7 @@ func TestInvalidIPv6Fragments(t *testing.T) { type fragmentData struct { ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6FragmentFields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr payload []byte } @@ -2014,14 +2018,13 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 9, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 9, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0 >> 3, M: true, Identification: ident, @@ -2041,14 +2044,13 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, M: false, Identification: ident, @@ -2089,10 +2091,9 @@ func TestInvalidIPv6Fragments(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - ip.Encode(&f.ipv6Fields) - - fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - fragHDR.Encode(&f.ipv6FragmentFields) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) @@ -2154,7 +2155,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { type fragmentData struct { ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6FragmentFields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr payload []byte } @@ -2168,14 +2169,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2190,14 +2190,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2206,14 +2205,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2228,14 +2226,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2250,14 +2247,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2266,14 +2262,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2288,14 +2283,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2304,14 +2298,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2350,10 +2343,11 @@ func TestFragmentReassemblyTimeout(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - ip.Encode(&f.ipv6Fields) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - fragHDR.Encode(&f.ipv6FragmentFields) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) @@ -2994,11 +2988,11 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: test.TTL, + SrcAddr: remoteIPv6Addr1, + DstAddr: remoteIPv6Addr2, }) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 4c06b3f0c..e8d1e7a79 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -40,6 +40,9 @@ type MLDOptions struct { // When enabled, MLD may transmit MLD report and done messages when // joining and leaving multicast groups respectively, and handle incoming // MLD packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). Enabled bool } @@ -55,22 +58,35 @@ type mldState struct { genericMulticastProtocol ip.GenericMulticastProtocolState } +// Enabled implements ip.MulticastGroupProtocol. +func (mld *mldState) Enabled() bool { + // No need to perform MLD on loopback interfaces since they don't have + // neighbouring nodes. + return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled() +} + // SendReport implements ip.MulticastGroupProtocol. -func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error { +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) } // SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { - return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + return err } // init sets up an mldState struct, and is required to be called before using // a new mldState. -func (mld *mldState) init(ep *endpoint, opts MLDOptions) { +// +// Must only be called once for the lifetime of mld. +func (mld *mldState) init(ep *endpoint) { mld.ep = ep - mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ - Enabled: opts.Enabled, + mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: mld, @@ -79,33 +95,45 @@ func (mld *mldState) init(ep *endpoint, opts MLDOptions) { }) } +// handleMulticastListenerQuery handles a query message. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { - mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) + mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) } +// handleMulticastListenerReport handles a report message. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { - mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress()) + mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress()) } // joinGroup handles joining a new group and sending and scheduling the required // messages. // // If the group is already joined, returns tcpip.ErrDuplicateAddress. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) joinGroup(groupAddress tcpip.Address) { - mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */) + mld.genericMulticastProtocol.JoinGroupLocked(groupAddress) } // isInGroup returns true if the specified group has been joined locally. +// +// Precondition: mld.ep.mu must be read locked. func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { - return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress) + return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Done message, if // required. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // LeaveGroup returns false only if the group was not joined. - if mld.genericMulticastProtocol.LeaveGroup(groupAddress) { + if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } @@ -114,17 +142,31 @@ func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // softLeaveAll leaves all groups from the perspective of MLD, but remains // joined locally. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) softLeaveAll() { - mld.genericMulticastProtocol.MakeAllNonMember() + mld.genericMulticastProtocol.MakeAllNonMemberLocked() } // initializeAll attemps to initialize the MLD state for each group that has // been joined locally. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) initializeAll() { - mld.genericMulticastProtocol.InitializeGroups() + mld.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) sendQueuedReports() { + mld.genericMulticastProtocol.SendQueuedReportsLocked() } -func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { +// writePacket assembles and sends an MLD packet. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent var mldStat *tcpip.StatCounter switch mldType { @@ -139,26 +181,82 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) icmp.SetType(mldType) header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) - // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, - // rather we should select an appropriate local address. - localAddress := header.IPv6Any + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert + // option in a Hop-by-Hop Options header. + // + // However, this would cause problems with Duplicate Address Detection with + // the first address as MLD snooping switches may not send multicast traffic + // that DAD depends on to the node performing DAD without the MLD report, as + // documented in RFC 4816: + // + // Note that when a node joins a multicast address, it typically sends a + // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810] + // for the multicast address. In the case of Duplicate Address + // Detection, the MLD report message is required in order to inform MLD- + // snooping switches, rather than routers, to forward multicast packets. + // In the above description, the delay for joining the multicast address + // thus means delaying transmission of the corresponding MLD report + // message. Since the MLD specifications do not request a random delay + // to avoid race conditions, just delaying Neighbor Solicitation would + // cause congestion by the MLD report messages. The congestion would + // then prevent the MLD-snooping switches from working correctly and, as + // a result, prevent Duplicate Address Detection from working. The + // requirement to include the delay for the MLD report in this case + // avoids this scenario. [RFC3590] also talks about some interaction + // issues between Duplicate Address Detection and MLD, and specifies + // which source address should be used for the MLD report in this case. + // + // As per RFC 3590 section 4, we should still send out MLD reports with an + // unspecified source address if we do not have an assigned link-local + // address to use as the source address to ensure DAD works as expected on + // networks with MLD snooping switches: + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + localAddress := mld.ep.getLinkLocalAddressRLocked() + if len(localAddress) == 0 { + localAddress = header.IPv6Any + } + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + extensionHeaders := header.IPv6ExtHdrSerializer{ + header.IPv6SerializableHopByHopExtHdr{ + &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, + }, + } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()), + ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(), Data: buffer.View(icmp).ToVectorisedView(), }) mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.MLDHopLimit, - }) - // TODO(b/162198658): set the ROUTER_ALERT option when sending Host - // Membership Reports. + }, extensionHeaders) if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sentStats.Dropped.Increment() - return err + return false, err } mldStat.Increment() - return nil + return localAddress != header.IPv6Any, nil } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 5677bdd54..e2778b656 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -16,8 +16,12 @@ package ipv6_test import ( "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -25,9 +29,34 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) +var ( + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) + globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) +) + +func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { + t.Helper() + + checker.IPv6WithExtHdr(t, p, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(localAddress), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(mldType, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { const nicID = 1 @@ -46,45 +75,223 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) } - { - p, ok := e.Read() - if !ok { - t.Fatal("expected a report message to be sent") - } - snmc := header.SolicitedNodeAddr(addr1) - checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), - checker.DstAddr(snmc), - // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. - checker.TTL(1), - checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(snmc), - ), - ) + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) } // The stack will leave an address's solicited node multicast address when // an address is removed. An MLD done message should be sent for the // solicited-node group. - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a done message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) } - { - p, ok := e.Read() - if !ok { - t.Fatal("expected a done message to be sent") - } - snmc := header.SolicitedNodeAddr(addr1) - checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(1), - checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(snmc), - ), - ) +} + +func TestSendQueuedMLDReports(t *testing.T) { + const ( + nicID = 1 + maxReports = 2 + ) + + tests := []struct { + name string + dadTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD Disabled", + dadTransmits: 0, + retransmitTimer: 0, + }, + { + name: "DAD Enabled", + dadTransmits: 1, + retransmitTimer: time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dadTransmits, + RetransmitTimer: test.retransmitTimer, + }, + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + Clock: clock, + }) + + // Allow space for an extra packet so we can observe packets that were + // unexpectedly sent. + e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + resolveDAD := func(addr, snmc tcpip.Address) { + clock.Advance(dadResolutionTime) + if p, ok := e.Read(); !ok { + t.Fatal("expected DAD packet") + } else { + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr), + checker.NDPNSOptions(nil), + )) + } + } + + var reportCounter uint64 + reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + var doneCounter uint64 + doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + + // Joining a group without an assigned address should send an MLD report + // with the unspecified address. + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalMulticastAddr) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a global address should not send reports for the already joined + // group since we should only send queued reports when a link-local + // addres sis assigned. + // + // Note, we will still expect to send a report for the global address's + // solicited node address from the unspecified address as per RFC 3590 + // section 4. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) + } + if dadResolutionTime != 0 { + // Reports should not be sent when the address resolves. + resolveDAD(globalAddr, globalAddrSNMC) + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + } + // Leave the group since we don't care about the global address's + // solicited node multicast group membership. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) + } + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a link-local address should send a report for its solicited node + // address and globalMulticastAddr. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + } + if dadResolutionTime != 0 { + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + resolveDAD(linkLocalAddr, linkLocalAddrSNMC) + } + + // We expect two batches of reports to be sent (1 batch when the + // link-local address is assigned, and another after the maximum + // unsolicited report interval. + for i := 0; i < 2; i++ { + // We expect reports to be sent (one for globalMulticastAddr and another + // for linkLocalAddrSNMC). + reportCounter += maxReports + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + + addrs := map[tcpip.Address]bool{ + globalMulticastAddr: false, + linkLocalAddrSNMC: false, + } + for _ = range addrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) + } + + addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() + if seen, ok := addrs[addr]; !ok { + t.Fatalf("got unexpected packet destined to %s", addr) + } else if seen { + t.Fatalf("got another packet destined to %s", addr) + } + + addrs[addr] = true + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) + + clock.Advance(ipv6.UnsolicitedReportIntervalMax) + } + } + + // Should not send any more reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + }) } } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 8cb7d4dab..d515eb622 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -20,6 +20,7 @@ import ( "math/rand" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + // The IPv6 endpoint this ndpState is for. ep *endpoint @@ -643,6 +647,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } + ndp.ep.onAddressAssignedLocked(addr) return nil } @@ -686,12 +691,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) } - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + if dadDone { + if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the + // generation of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } + + ndp.ep.onAddressAssignedLocked(addr) } }), } @@ -728,7 +735,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - }) + }, nil /* extensionHeaders */) if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() @@ -1850,7 +1857,7 @@ func (ndp *ndpState) startSolicitingRouters() { ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - }) + }, nil /* extensionHeaders */) if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() @@ -1884,11 +1891,19 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitJob = nil } -// initializeTempAddrState initializes state related to temporary SLAAC -// addresses. -func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) +func (ndp *ndpState) init(ep *endpoint) { + if ndp.dad != nil { + panic("attempted to initialize NDP state twice") + } + ndp.ep = ep + ndp.configs = ep.protocol.options.NDPConfigs + ndp.dad = make(map[tcpip.Address]dadState) + ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) + ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) + + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 95c626bb8..7ddb19c00 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -213,11 +213,11 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -319,11 +319,11 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -599,11 +599,11 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != respNSDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) } - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -681,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, }) e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != test.naDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) } - if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr) + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -785,11 +785,11 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -898,11 +898,11 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -979,29 +979,25 @@ func TestNDPValidation(t *testing.T) { } handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - nextHdr := uint8(header.ICMPv6ProtocolNumber) - var extensions buffer.View + var extHdrs header.IPv6ExtHdrSerializer if atomicFragment { - extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) - extensions[0] = nextHdr - nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) } + extHdrsLen := extHdrs.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen, Data: payload.ToVectorisedView(), }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + len(extensions)), - NextHeader: nextHdr, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(payload) + extHdrsLen), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: hopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + ExtensionHeaders: extHdrs, }) - if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { - t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) - } ep.HandlePacket(pkt) } @@ -1351,11 +1347,11 @@ func TestRouterAdvertValidation(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, }) stats := s.Stats().ICMP.V6.PacketsReceived diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 95fb67986..05d98a0a5 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -34,6 +35,9 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + ipv4Addr = tcpip.Address("\x0a\x00\x00\x01") + ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") @@ -48,6 +52,8 @@ const ( mldQuery = uint8(header.ICMPv6MulticastListenerQuery) mldReport = uint8(header.ICMPv6MulticastListenerReport) mldDone = uint8(header.ICMPv6MulticastListenerDone) + + maxUnsolicitedReports = 2 ) var ( @@ -61,6 +67,8 @@ var ( } return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) }() + + ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) ) // validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet @@ -69,7 +77,11 @@ func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.A t.Helper() payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv6(t, payload, + checker.IPv6WithExtHdr(t, payload, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(ipv6Addr), checker.DstAddr(remoteAddress), // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. checker.TTL(1), @@ -87,6 +99,7 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv4(t, payload, + checker.SrcAddr(ipv4Addr), checker.DstAddr(remoteAddress), // TTL for an IGMP message must be 1 as per RFC 2236 section 2. checker.TTL(1), @@ -99,23 +112,31 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. ) } -func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { +func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() - // Create an endpoint of queue size 2, since no more than 2 packets are ever - // queued in the tests in this file. - e := channel.New(2, 1280, linkAddr) + e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) + s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) + return e, s, clock +} + +func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { + t.Helper() + + igmpEnabled := v4 && mgpEnabled + mldEnabled := !v4 && mgpEnabled + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocolWithOptions(ipv4.Options{ IGMP: ipv4.IGMPOptions{ - Enabled: mgpEnabled, + Enabled: igmpEnabled, }, }), ipv6.NewProtocolWithOptions(ipv6.Options{ MLD: ipv6.MLDOptions{ - Enabled: mgpEnabled, + Enabled: mldEnabled, }, }), }, @@ -124,8 +145,59 @@ func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err) + } - return e, s, clock + return s, clock +} + +// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join +// when it is created with an IPv6 address. +// +// To not interfere with tests, checkInitialIPv6Groups will leave the added +// address's solicited node multicast group so that the tests can all assume +// the NIC has not joined any IPv6 groups. +func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { + t.Helper() + + stats := s.Stats().ICMP.V6.PacketsSent + + reportCounter++ + if got := stats.MulticastListenerReport.Value(); got != reportCounter { + t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) + } + + // Leave the group to not affect the tests. This is fine since we are not + // testing DAD or the solicited node address specifically. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) + } + leaveCounter++ + if got := stats.MulticastListenerDone.Value(); got != leaveCounter { + t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + return reportCounter, leaveCounter } // createAndInjectIGMPPacket creates and injects an IGMP packet with the @@ -170,11 +242,11 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay b ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - HopLimit: header.MLDHopLimit, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - SrcAddr: header.IPv4Any, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(icmpSize), + HopLimit: header.MLDHopLimit, + TransportProtocol: header.ICMPv6ProtocolNumber, + SrcAddr: header.IPv4Any, + DstAddr: header.IPv6AllNodesMulticastAddress, }) icmp := header.ICMPv6(buf[header.IPv6MinimumSize:]) @@ -232,13 +304,13 @@ func TestMGPDisabled(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, false) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) // This NIC may join multicast groups when it is enabled but since MGP is // disabled, no reports should be sent. sentReportStat := test.sentReportStat(s) if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) } clock.Advance(time.Hour) if p, ok := e.Read(); ok { @@ -251,7 +323,7 @@ func TestMGPDisabled(t *testing.T) { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) } clock.Advance(time.Hour) if p, ok := e.Read(); ok { @@ -355,7 +427,7 @@ func TestMGPReceiveCounters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) + e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) if got := test.statCounter(s).Value(); got != 1 { @@ -376,6 +448,7 @@ func TestMGPJoinGroup(t *testing.T) { sentReportStat func(*stack.Stack) *tcpip.StatCounter receivedQueryStat func(*stack.Stack) *tcpip.StatCounter validateReport func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -410,21 +483,28 @@ func TestMGPJoinGroup(t *testing.T) { validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } // Test joining a specific address explicitly and verify a Report is sent // immediately. if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } + reportCounter++ sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportState.Value() = %d, want = 1", got) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -442,8 +522,9 @@ func TestMGPJoinGroup(t *testing.T) { t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) } clock.Advance(test.maxUnsolicitedResponseDelay) - if got := sentReportStat.Value(); got != 2 { - t.Errorf("got sentReportState.Value() = %d, want = 2", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -464,13 +545,14 @@ func TestMGPJoinGroup(t *testing.T) { // group the stack sends a leave/done message. func TestMGPLeaveGroup(t *testing.T) { tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - validateLeave func(*testing.T, channel.PacketInfo) + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + validateLeave func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -513,18 +595,26 @@ func TestMGPLeaveGroup(t *testing.T) { validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } - if got := test.sentReportStat(s).Value(); got != 1 { - t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got) + reportCounter++ + if got := test.sentReportStat(s).Value(); got != reportCounter { + t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -539,8 +629,9 @@ func TestMGPLeaveGroup(t *testing.T) { if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) } - if got := test.sentLeaveStat(s).Value(); got != 1 { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got) + leaveCounter++ + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a leave message to be sent") @@ -570,6 +661,7 @@ func TestMGPQueryMessages(t *testing.T) { rxQuery func(*channel.Endpoint, uint8, tcpip.Address) validateReport func(*testing.T, channel.PacketInfo) maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -614,6 +706,7 @@ func TestMGPQueryMessages(t *testing.T) { maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond }, + checkInitialGroups: checkInitialIPv6Groups, }, } @@ -647,16 +740,22 @@ func TestMGPQueryMessages(t *testing.T) { for _, subTest := range subTests { t.Run(subTest.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } sentReportStat := test.sentReportStat(s) - for i := uint64(1); i <= 2; i++ { + for i := 0; i < maxUnsolicitedReports; i++ { sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != i { - t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatalf("expected %d-th report message to be sent", i) @@ -686,8 +785,9 @@ func TestMGPQueryMessages(t *testing.T) { if subTest.expectReport { clock.Advance(test.maxRespTimeToDuration(maxRespTime)) - if got := sentReportStat.Value(); got != 3 { - t.Errorf("got sentReportState.Value() = %d, want = 3", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -719,6 +819,7 @@ func TestMGPReportMessages(t *testing.T) { rxReport func(*channel.Endpoint) validateReport func(*testing.T, channel.PacketInfo) maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -761,19 +862,27 @@ func TestMGPReportMessages(t *testing.T) { maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -788,8 +897,8 @@ func TestMGPReportMessages(t *testing.T) { // reports. test.rxReport(e) clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); ok { t.Errorf("sent unexpected packet = %#v", p) @@ -804,8 +913,8 @@ func TestMGPReportMessages(t *testing.T) { t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) } clock.Advance(time.Hour) - if got := test.sentLeaveStat(s).Value(); got != 0 { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got) + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) } // Should not send any more packets. @@ -829,6 +938,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -897,10 +1007,31 @@ func TestMGPWithNICLifecycle(t *testing.T) { t.Helper() ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - if got := tcpip.TransportProtocolNumber(ipv6.NextHeader()); got != header.ICMPv6ProtocolNumber { + + ipv6HeaderIter := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var transport header.IPv6RawPayloadHeader + for { + h, done, err := ipv6HeaderIter.Next() + if err != nil { + t.Fatalf("ipv6HeaderIter.Next(): %s", err) + } + if done { + t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) + } + if t, ok := h.(header.IPv6RawPayloadHeader); ok { + transport = t + break + } + } + + if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) } - icmpv6 := header.ICMPv6(ipv6.Payload()) + icmpv6 := header.ICMPv6(transport.Buf.ToView()) if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) } @@ -914,17 +1045,22 @@ func TestMGPWithNICLifecycle(t *testing.T) { } seen[addr] = true return addr - }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - sentReportStat := test.sentReportStat(s) var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + sentReportStat := test.sentReportStat(s) for _, a := range test.multicastAddrs { if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) @@ -949,7 +1085,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { t.Fatalf("DisableNIC(%d): %s", nicID, err) } sentLeaveStat := test.sentLeaveStat(s) - leaveCounter := uint64(len(test.multicastAddrs)) + leaveCounter += uint64(len(test.multicastAddrs)) if got := sentLeaveStat.Value(); got != leaveCounter { t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) } @@ -1051,7 +1187,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { clock.Advance(test.maxUnsolicitedResponseDelay) reportCounter++ if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter) + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -1067,3 +1203,59 @@ func TestMGPWithNICLifecycle(t *testing.T) { }) } } + +// TestMGPDisabledOnLoopback tests that the multicast group protocol is not +// performed on loopback interfaces since they have no neighbours. +func TestMGPDisabledOnLoopback(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) + + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + }) + } +} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index c53698a6a..f3ad40fdf 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -16,6 +16,8 @@ package tcpip import ( "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" ) // SocketOptionsHandler holds methods that help define endpoint specific @@ -37,6 +39,15 @@ type SocketOptionsHandler interface { // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint. OnCorkOptionSet(v bool) + + // LastError is invoked when SO_ERROR is read for an endpoint. + LastError() *Error + + // UpdateLastError updates the endpoint specific last error field. + UpdateLastError(err *Error) + + // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. + HasNIC(v int32) bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -60,6 +71,19 @@ func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {} // OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet. func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {} +// LastError implements SocketOptionsHandler.LastError. +func (*DefaultSocketOptionsHandler) LastError() *Error { + return nil +} + +// UpdateLastError implements SocketOptionsHandler.UpdateLastError. +func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} + +// HasNIC implements SocketOptionsHandler.HasNIC. +func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { + return false +} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. // @@ -69,24 +93,24 @@ type SocketOptions struct { // These fields are accessed and modified using atomic operations. - // broadcastEnabled determines whether datagram sockets are allowed to send - // packets to a broadcast address. + // broadcastEnabled determines whether datagram sockets are allowed to + // send packets to a broadcast address. broadcastEnabled uint32 - // passCredEnabled determines whether SCM_CREDENTIALS socket control messages - // are enabled. + // passCredEnabled determines whether SCM_CREDENTIALS socket control + // messages are enabled. passCredEnabled uint32 // noChecksumEnabled determines whether UDP checksum is disabled while // transmitting for this socket. noChecksumEnabled uint32 - // reuseAddressEnabled determines whether Bind() should allow reuse of local - // address. + // reuseAddressEnabled determines whether Bind() should allow reuse of + // local address. reuseAddressEnabled uint32 - // reusePortEnabled determines whether to permit multiple sockets to be bound - // to an identical socket address. + // reusePortEnabled determines whether to permit multiple sockets to be + // bound to an identical socket address. reusePortEnabled uint32 // keepAliveEnabled determines whether TCP keepalive is enabled for this @@ -94,7 +118,7 @@ type SocketOptions struct { keepAliveEnabled uint32 // multicastLoopEnabled determines whether multicast packets sent over a - // non-loopback interface will be looped back. Analogous to inet->mc_loop. + // non-loopback interface will be looped back. multicastLoopEnabled uint32 // receiveTOSEnabled is used to specify if the TOS ancillary message is @@ -130,6 +154,28 @@ type SocketOptions struct { // corkOptionEnabled is used to specify if data should be held until segments // are full by the TCP transport protocol. corkOptionEnabled uint32 + + // receiveOriginalDstAddress is used to specify if the original destination of + // the incoming packet should be returned as an ancillary message. + receiveOriginalDstAddress uint32 + + // recvErrEnabled determines whether extended reliable error message passing + // is enabled. + recvErrEnabled uint32 + + // errQueue is the per-socket error queue. It is protected by errQueueMu. + errQueueMu sync.Mutex `state:"nosave"` + errQueue sockErrorList + + // bindToDevice determines the device to which the socket is bound. + bindToDevice int32 + + // mu protects the access to the below fields. + mu sync.Mutex `state:"nosave"` + + // linger determines the amount of time the socket should linger before + // close. We currently implement this option for TCP socket only. + linger LingerOption } // InitHandler initializes the handler. This must be called before using the @@ -146,6 +192,11 @@ func storeAtomicBool(addr *uint32, v bool) { atomic.StoreUint32(addr, val) } +// SetLastError sets the last error for a socket. +func (so *SocketOptions) SetLastError(err *Error) { + so.handler.UpdateLastError(err) +} + // GetBroadcast gets value for SO_BROADCAST option. func (so *SocketOptions) GetBroadcast() bool { return atomic.LoadUint32(&so.broadcastEnabled) != 0 @@ -302,3 +353,168 @@ func (so *SocketOptions) SetCorkOption(v bool) { storeAtomicBool(&so.corkOptionEnabled, v) so.handler.OnCorkOptionSet(v) } + +// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) GetReceiveOriginalDstAddress() bool { + return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0 +} + +// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { + storeAtomicBool(&so.receiveOriginalDstAddress, v) +} + +// GetRecvError gets value for IP*_RECVERR option. +func (so *SocketOptions) GetRecvError() bool { + return atomic.LoadUint32(&so.recvErrEnabled) != 0 +} + +// SetRecvError sets value for IP*_RECVERR option. +func (so *SocketOptions) SetRecvError(v bool) { + storeAtomicBool(&so.recvErrEnabled, v) + if !v { + so.pruneErrQueue() + } +} + +// GetLastError gets value for SO_ERROR option. +func (so *SocketOptions) GetLastError() *Error { + return so.handler.LastError() +} + +// GetOutOfBandInline gets value for SO_OOBINLINE option. +func (*SocketOptions) GetOutOfBandInline() bool { + return true +} + +// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not +// support disabling this option. +func (*SocketOptions) SetOutOfBandInline(bool) {} + +// GetLinger gets value for SO_LINGER option. +func (so *SocketOptions) GetLinger() LingerOption { + so.mu.Lock() + linger := so.linger + so.mu.Unlock() + return linger +} + +// SetLinger sets value for SO_LINGER option. +func (so *SocketOptions) SetLinger(linger LingerOption) { + so.mu.Lock() + so.linger = linger + so.mu.Unlock() +} + +// SockErrOrigin represents the constants for error origin. +type SockErrOrigin uint8 + +const ( + // SockExtErrorOriginNone represents an unknown error origin. + SockExtErrorOriginNone SockErrOrigin = iota + + // SockExtErrorOriginLocal indicates a local error. + SockExtErrorOriginLocal + + // SockExtErrorOriginICMP indicates an IPv4 ICMP error. + SockExtErrorOriginICMP + + // SockExtErrorOriginICMP6 indicates an IPv6 ICMP error. + SockExtErrorOriginICMP6 +) + +// IsICMPErr indicates if the error originated from an ICMP error. +func (origin SockErrOrigin) IsICMPErr() bool { + return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6 +} + +// SockError represents a queue entry in the per-socket error queue. +// +// +stateify savable +type SockError struct { + sockErrorEntry + + // Err is the error caused by the errant packet. + Err *Error + // ErrOrigin indicates the error origin. + ErrOrigin SockErrOrigin + // ErrType is the type in the ICMP header. + ErrType uint8 + // ErrCode is the code in the ICMP header. + ErrCode uint8 + // ErrInfo is additional info about the error. + ErrInfo uint32 + + // Payload is the errant packet's payload. + Payload []byte + // Dst is the original destination address of the errant packet. + Dst FullAddress + // Offender is the original sender address of the errant packet. + Offender FullAddress + // NetProto is the network protocol being used to transmit the packet. + NetProto NetworkProtocolNumber +} + +// pruneErrQueue resets the queue. +func (so *SocketOptions) pruneErrQueue() { + so.errQueueMu.Lock() + so.errQueue.Reset() + so.errQueueMu.Unlock() +} + +// DequeueErr dequeues a socket extended error from the error queue and returns +// it. Returns nil if queue is empty. +func (so *SocketOptions) DequeueErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + + err := so.errQueue.Front() + if err != nil { + so.errQueue.Remove(err) + } + return err +} + +// PeekErr returns the error in the front of the error queue. Returns nil if +// the error queue is empty. +func (so *SocketOptions) PeekErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + return so.errQueue.Front() +} + +// QueueErr inserts the error at the back of the error queue. +// +// Preconditions: so.GetRecvError() == true. +func (so *SocketOptions) QueueErr(err *SockError) { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + so.errQueue.PushBack(err) +} + +// QueueLocalErr queues a local error onto the local queue. +func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { + so.QueueErr(&SockError{ + Err: err, + ErrOrigin: SockExtErrorOriginLocal, + ErrInfo: info, + Payload: payload, + Dst: dst, + NetProto: net, + }) +} + +// GetBindToDevice gets value for SO_BINDTODEVICE option. +func (so *SocketOptions) GetBindToDevice() int32 { + return atomic.LoadInt32(&so.bindToDevice) +} + +// SetBindToDevice sets value for SO_BINDTODEVICE option. +func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { + if !so.handler.HasNIC(bindToDevice) { + return ErrUnknownDevice + } + + atomic.StoreInt32(&so.bindToDevice, bindToDevice) + return nil +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 9cc6074da..bb30556cf 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -148,7 +148,6 @@ go_test( ], library = ":stack", deps = [ - "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 6e4f5fa46..cd423bf71 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -82,12 +82,16 @@ func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) } // ForEachPrimaryEndpoint calls f for each primary address. -func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { +// +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) { a.mu.RLock() defer a.mu.RUnlock() for _, ep := range a.mu.primary { - f(ep) + if !f(ep) { + return + } } } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 5ec9b3411..93e8e1c51 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -560,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) { } } +func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 50 * time.Millisecond, + onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + // Don't resolve the link address. + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */) + + const numPackets int = 5 + // These packets will all be enqueued in the packet queue to wait for link + // address resolution. + for i := 0; i < numPackets; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + // All packets should fail resolution. + // TODO(gvisor.dev/issue/5141): Use a fake clock. + for i := 0; i < numPackets; i++ { + select { + case got := <-ep2.C: + t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) + case <-time.After(100 * time.Millisecond): + } + } +} + func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index c9b13cd0e..792f4f170 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -18,7 +18,6 @@ import ( "fmt" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -58,9 +57,6 @@ const ( incomplete entryState = iota // ready means that the address has been resolved and can be used. ready - // failed means that address resolution timed out and the address - // could not be resolved. - failed ) // String implements Stringer. @@ -70,8 +66,6 @@ func (s entryState) String() string { return "incomplete" case ready: return "ready" - case failed: - return "failed" default: return fmt.Sprintf("unknown(%d)", s) } @@ -80,40 +74,48 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry + // TODO(gvisor.dev/issue/5150): move these fields under mu. + // mu protects the fields below. + mu sync.RWMutex + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState - // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of incomplete these waiters are notified. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil iff - // s is incomplete and resolution is not yet in progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) } -// changeState sets the entry's state to ns, notifying any waiters. +func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { + for _, callback := range e.onResolve { + callback(linkAddr, len(linkAddr) != 0) + } + e.onResolve = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// changeStateLocked sets the entry's state to ns. // // The entry's expiration is bumped up to the greater of itself and the passed // expiration; the zero value indicates immediate expiration, and is set // unconditionally - this is an implementation detail that allows for entries // to be reused. -func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { - // Notify whoever is waiting on address resolution when transitioning - // out of incomplete. - if e.s == incomplete && ns != incomplete { - for w := range e.wakers { - w.Assert() - } - e.wakers = nil - if ch := e.done; ch != nil { - close(ch) - } - e.done = nil +// +// Precondition: e.mu must be locked +func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { + if e.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.linkAddr) } if expiration.IsZero() || expiration.After(e.expiration) { @@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { e.s = ns } -func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { - delete(e.wakers, w) -} - // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is @@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { c.cache.Lock() entry := c.getOrCreateEntryLocked(k) - entry.linkAddr = v - - entry.changeState(ready, expiration) c.cache.Unlock() + + entry.mu.Lock() + defer entry.mu.Unlock() + entry.linkAddr = v + entry.changeStateLocked(ready, expiration) } // getOrCreateEntryLocked retrieves a cache entry associated with k. The @@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt var entry *linkAddrEntry if len(c.cache.table) == linkAddrCacheSize { entry = c.cache.lru.Back() + entry.mu.Lock() delete(c.cache.table, entry.addr) c.cache.lru.Remove(entry) - // Wake waiters and mark the soon-to-be-reused entry as expired. Note - // that the state passed doesn't matter when the zero time is passed. - entry.changeState(failed, time.Time{}) + // Wake waiters and mark the soon-to-be-reused entry as expired. + entry.notifyCompletionLocked("" /* linkAddr */) + entry.mu.Unlock() } else { entry = new(linkAddrEntry) } @@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { + if onResolve != nil { + onResolve(addr, true) + } return addr, nil, nil } } @@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo c.cache.Lock() defer c.cache.Unlock() entry := c.getOrCreateEntryLocked(k) + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: + case ready: if !time.Now().After(entry.expiration) { // Not expired. - switch s { - case ready: - return entry.linkAddr, nil, nil - case failed: - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + if onResolve != nil { + onResolve(entry.linkAddr, true) } + return entry.linkAddr, nil, nil } - entry.changeState(incomplete, time.Time{}) + entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: - if waker != nil { - if entry.wakers == nil { - entry.wakers = make(map[*sleep.Waker]struct{}) - } - entry.wakers[waker] = struct{}{} + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) } - if entry.done == nil { - // Address resolution needs to be initiated. - if linkRes == nil { - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - } - entry.done = make(chan struct{}) go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -// removeWaker removes a waker previously added through get(). -func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.cache.Lock() - defer c.cache.Unlock() - - if entry, ok := c.cache.table[k]; ok { - entry.removeWaker(waker) - } -} - func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check @@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link } } -// checkLinkRequest checks whether previous attempt to resolve address has succeeded -// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request -// can stop, false if another request should be sent. +// checkLinkRequest checks whether previous attempt to resolve address has +// succeeded and mark the entry accordingly. Returns true if request can stop, +// false if another request should be sent. func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { c.cache.Lock() defer c.cache.Unlock() @@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att // Entry was evicted from the cache. return true } + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: - // Entry was made ready by resolver or failed. Either way we're done. + case ready: + // Entry was made ready by resolver. case incomplete: if attempt+1 < c.resolutionAttempts { // No response yet, need to send another ARP request. return false } - // Max number of retries reached, mark entry as failed. - entry.changeState(failed, now.Add(c.ageLimit)) + // Max number of retries reached, delete entry. + entry.notifyCompletionLocked("" /* linkAddr */) + delete(c.cache.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d2e37f38d..6883045b5 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -21,7 +21,6 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -50,6 +49,7 @@ type testLinkAddressResolver struct { } func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() @@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe } func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - + var attemptedResolution bool for { - if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - return got, err + got, ch, err := c.get(addr, linkRes, "", nil, nil) + if err == tcpip.ErrWouldBlock { + if attemptedResolution { + return got, tcpip.ErrNoLinkAddress + } + attemptedResolution = true + <-ch + continue } - s.Fetch(true) + return got, err } } @@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. + c.cache.Lock() + defer c.cache.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } } func TestCacheConcurrent(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup for r := 0; r < 16; r++ { @@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) { go func() { for _, e := range testAddrs { c.add(e.addr, e.linkAddr) - c.get(e.addr, nil, "", nil, nil) // make work for gotsan } wg.Done() }() @@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) { } e = testAddrs[0] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + c.cache.Lock() + defer c.cache.Unlock() + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) } } @@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) } } - -// TestCacheWaker verifies that RemoveWaker removes a waker previously added -// through get(). -func TestCacheWaker(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - - // First, sanity check that wakers are working. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 1 - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[0] - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - id, ok := s.Fetch(true /* block */) - if !ok { - t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)") - } - if id != wakerID { - t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID) - } - - if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - } - - // Check that RemoveWaker works. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 2 // different than the ID used in the sanity check - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[1] - linkRes.onLinkAddressRequest = func() { - // Remove the waker before the linkAddrCache has the opportunity to send - // a notification. - c.removeWaker(e.addr, &w) - } - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - - if got, err := getBlocking(c, e.addr, linkRes); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Fatalf("unexpected notification from waker with id %d", id) - } - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 31b67b987..61636cae5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -540,8 +540,8 @@ func TestDADResolve(t *testing.T) { // Make sure the right remote link address is used. snmc := header.SolicitedNodeAddr(addr1) - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { - t.Errorf("got remote link address = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } // Check NDP NS packet. @@ -577,11 +577,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) } @@ -623,11 +623,11 @@ func TestDADFail(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: tgt, + DstAddr: header.IPv6AllNodesMulticastAddress, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, @@ -1011,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, }) return stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -5197,8 +5197,8 @@ func TestRouterSolicitation(t *testing.T) { } // Make sure the right remote link address is used. - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want { - t.Errorf("got remote link address = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 317f6871d..c15f10e76 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -99,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA n.dynamic.lru.Remove(e) n.dynamic.count-- - e.dispatchRemoveEventLocked() - e.setStateLocked(Unknown) - e.notifyWakersLocked() + e.removeLocked() e.mu.Unlock() } n.cache[remoteAddr] = entry @@ -110,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA return entry } -// entry looks up the neighbor cache for translating address to link address -// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there -// is a LinkAddressResolver registered with the network protocol, the cache -// attempts to resolve the address and returns ErrWouldBlock. If a Waker is -// provided, it will be notified when address resolution is complete (success -// or not). +// entry looks up neighbor information matching the remote address, and returns +// it if readily available. +// +// Returns ErrWouldBlock if the link address is not readily available, along +// with a notification channel for the caller to block on. Triggers address +// resolution asynchronously. +// +// If onResolve is provided, it will be called either immediately, if resolution +// is not required, or when address resolution is complete, with the resolved +// link address and whether resolution succeeded. After any callbacks have been +// called, the returned notification channel is closed. +// +// NB: if a callback is provided, it should not call into the neighbor cache. // // If specified, the local address must be an address local to the interface the // neighbor cache belongs to. The local address is the source address of a // packet prompting NUD/link address resolution. // -// If address resolution is required, ErrNoLinkAddress and a notification -// channel is returned for the top level caller to block. Channel is closed -// once address resolution is complete (success or not). -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve. if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ Addr: remoteAddr, @@ -132,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA State: Static, UpdatedAtNanos: 0, } + if onResolve != nil { + onResolve(linkAddr, true) + } return e, nil, nil } @@ -149,37 +155,25 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // of packets to a neighbor. While reasserting a neighbor's reachability, // a node continues sending packets to that neighbor using the cached // link-layer address." + if onResolve != nil { + onResolve(entry.neigh.LinkAddr, true) + } return entry.neigh, nil, nil - case Unknown, Incomplete: - entry.addWakerLocked(w) - + case Unknown, Incomplete, Failed: + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) + } if entry.done == nil { // Address resolution needs to be initiated. - if linkRes == nil { - return entry.neigh, nil, tcpip.ErrNoLinkAddress - } entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock - case Failed: - return entry.neigh, nil, tcpip.ErrNoLinkAddress default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } } -// removeWaker removes a waker that has been added when link resolution for -// addr was requested. -func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { - n.mu.Lock() - if entry, ok := n.cache[addr]; ok { - delete(entry.wakers, waker) - } - n.mu.Unlock() -} - // entries returns all entries in the neighbor cache. func (n *neighborCache) entries() []NeighborEntry { n.mu.RLock() @@ -222,34 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd return } - // Notify that resolution has been interrupted, just in case the entry was - // in the Incomplete or Probe state. - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } -// removeEntryLocked removes the specified entry from the neighbor cache. -// -// Prerequisite: n.mu and entry.mu MUST be locked. -func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.count-- - } - if entry.neigh.State != Failed { - entry.dispatchRemoveEventLocked() - } - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() - - delete(n.cache, entry.neigh.Addr) -} - // removeEntry removes a dynamic or static entry by address from the neighbor // cache. Returns true if the entry was found and deleted. func (n *neighborCache) removeEntry(addr tcpip.Address) bool { @@ -264,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { entry.mu.Lock() defer entry.mu.Unlock() - n.removeEntryLocked(entry) + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + + entry.removeLocked() + delete(n.cache, entry.neigh.Addr) return true } @@ -275,9 +254,7 @@ func (n *neighborCache) clear() { for _, entry := range n.cache { entry.mu.Lock() - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 732a299f7..a2ed6ae2a 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -28,7 +28,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" ) @@ -190,15 +189,18 @@ type testNeighborResolver struct { entries *testEntryStore delay time.Duration onLinkAddressRequest func() + dropReplies bool } var _ LinkAddressResolver = (*testNeighborResolver)(nil) func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { - // Delay handling the request to emulate network latency. - r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(targetAddr) - }) + if !r.dropReplies { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(targetAddr) + }) + } // Execute post address resolution action, if available. if f := r.onLinkAddressRequest; f != nil { @@ -291,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -327,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -354,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -413,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -461,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -513,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { } // Expect to find only the most recent entries. The order of entries reported - // by entries() is undeterministic, so entries have to be sorted before + // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { @@ -575,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ @@ -650,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -694,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -756,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -826,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -907,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { } } -func TestNeighborCacheNotifiesWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - id, ok := s.Fetch(false /* block */) - if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) - } - if id != wakerID { - t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - -func TestNeighborCacheRemoveWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - - // Remove the waker before the neighbor cache has the opportunity to send a - // notification. - neigh.removeWaker(entry.Addr, &w) - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Errorf("unexpected notification from waker with id %d", id) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { config := DefaultNUDConfigurations() // Stay in Reachable so the cache can overflow @@ -1062,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1075,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -1129,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) { // Add a dynamic entry. entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1187,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) { } } - // Clear shoud remove both dynamic and static entries. + // Clear should remove both dynamic and static entries. neigh.clear() // Remove events dispatched from clear() have no deterministic order so they @@ -1234,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1318,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { frequentlyUsedEntry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1330,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1373,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1381,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1435,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Expect to find only the frequently used entry and the most recent entries. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { @@ -1494,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { go func(entry NeighborEntry) { defer wg.Done() if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } - // Wait for all gorountines to send a request + // Wait for all goroutines to send a request wg.Wait() // Process all the requests for a single entry concurrently @@ -1509,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than can fit in // the cache. Our eviction strategy requires that the last entries are // present, up to the size of the neighbor cache, and the rest are missing. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. var wantUnsortedEntries []NeighborEntry for i := store.size() - neighborCacheSize; i < store.size(); i++ { @@ -1547,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) { // Add an entry entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) - } - if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) + t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1578,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1587,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) { { entry, ok := store.entry(1) if !ok { - t.Fatalf("store.entry(1) not found") + t.Fatal("store.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } @@ -1604,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) { { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1612,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } @@ -1622,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1630,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1654,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { }, } - // First, sanity check that resolution is working entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + // First, sanity check that resolution is working + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } - clock.Advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1673,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } - // Verify that address resolution for an unknown address returns ErrNoLinkAddress + // Verify address resolution fails for an unknown address. before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } maxAttempts := neigh.config().MaxUnicastProbes @@ -1714,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } +} + +// TestNeighborCacheRetryResolution simulates retrying communication after +// failing to perform address resolution. +func TestNeighborCacheRetryResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + // Simulate a faulty link. + dropReplies: true, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatal("store.entry(0) not found") + } + + // Perform address resolution with a faulty link, which will fail. + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } + } + + // Verify the entry is in Failed state. + wantEntries := []NeighborEntry{ + { + Addr: entry.Addr, + LinkAddr: "", + State: Failed, + }, + } + if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // Retry address resolution with a working link. + linkRes.dropReplies = false + { + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + if incompleteEntry.State != Incomplete { + t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) + } + clock.Advance(typicalLatency) + + select { + case <-ch: + if !ok { + t.Fatal("expected successful address resolution") + } + reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if err != nil { + t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + } + if reachableEntry.Addr != entry.Addr { + t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) + } + if reachableEntry.LinkAddr != entry.LinkAddr { + t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr) + } + if reachableEntry.State != Reachable { + t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) + } + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } } @@ -1742,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ Addr: testEntryBroadcastAddr, @@ -1750,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1775,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + b.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } - if doneCh != nil { - <-doneCh + + select { + case <-ch: + default: + b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 32399b4f5..75afb3001 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -19,7 +19,6 @@ import ( "sync" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -67,8 +66,7 @@ const ( // Static describes entries that have been explicitly added by the user. They // do not expire and are not deleted until explicitly removed. Static - // Failed means traffic should not be sent to this neighbor since attempts of - // reachability have returned inconclusive. + // Failed means recent attempts of reachability have returned inconclusive. Failed ) @@ -93,16 +91,13 @@ type neighborEntry struct { neigh NeighborEntry - // wakers is a set of waiters for address resolution result. Anytime state - // transitions out of incomplete these waiters are notified. It is nil iff - // address resolution is ongoing and no clients are waiting for the result. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil - // iff nudState is not Reachable and address resolution is not yet in - // progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) + isRouter bool job *tcpip.Job } @@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd } } -// addWaker adds w to the list of wakers waiting for address resolution. -// Assumes the entry has already been appropriately locked. -func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { - if w == nil { - return - } - if e.wakers == nil { - e.wakers = make(map[*sleep.Waker]struct{}) - } - e.wakers[w] = struct{}{} -} - -// notifyWakersLocked notifies those waiting for address resolution, whether it -// succeeded or failed. Assumes the entry has already been appropriately locked. -func (e *neighborEntry) notifyWakersLocked() { - for w := range e.wakers { - w.Assert() +// notifyCompletionLocked notifies those waiting for address resolution, with +// the link address if resolution completed successfully. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + for _, callback := range e.onResolve { + callback(e.neigh.LinkAddr, succeeded) } - e.wakers = nil + e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil @@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborAdded(e.nic.id, e.neigh) @@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborChanged(e.nic.id, e.neigh) @@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry // has been removed. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } +// cancelJobLocked cancels the currently scheduled action, if there is one. +// Entries in Unknown, Stale, or Static state do not have a scheduled action. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) cancelJobLocked() { + if job := e.job; job != nil { + job.Cancel() + } +} + +// removeLocked prepares the entry for removal. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) removeLocked() { + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.dispatchRemoveEventLocked() + e.cancelJobLocked() + e.notifyCompletionLocked(false /* succeeded */) +} + // setStateLocked transitions the entry to the specified state immediately. // // Follows the logic defined in RFC 4861 section 7.3.3. // -// e.mu MUST be locked. +// Precondition: e.mu MUST be locked. func (e *neighborEntry) setStateLocked(next NeighborState) { - // Cancel the previously scheduled action, if there is one. Entries in - // Unknown, Stale, or Static state do not have scheduled actions. - if timer := e.job; timer != nil { - timer.Cancel() - } + e.cancelJobLocked() prev := e.neigh.State e.neigh.State = next @@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { e.job.Schedule(immediateDuration) case Failed: - e.notifyWakersLocked() - e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() { - e.nic.neigh.removeEntryLocked(e) - }) - e.job.Schedule(config.UnreachableTime) + e.notifyCompletionLocked(false /* succeeded */) case Unknown, Stale, Static: // Do nothing @@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { + case Failed: + e.nic.stats.Neighbor.FailedEntryLookups.Increment() + + fallthrough case Unknown: e.neigh.State = Incomplete e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() @@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // implementation may find it convenient in some cases to return errors // to the sender by taking the offending packet, generating an ICMP // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 + // error-handling routines." - RFC 4861 section 2.1 e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return @@ -349,8 +358,6 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Incomplete, Reachable, Delay, Probe, Static: // Do nothing - case Failed: - e.nic.stats.Neighbor.FailedEntryLookups.Increment() default: panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } @@ -360,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // Neighbor Solicitation for ARP or NDP, respectively). // // Follows the logic defined in RFC 4861 section 7.2.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // Probes MUST be silently discarded if the target address is tentative, does // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These // checks MUST be done by the NetworkEndpoint. switch e.neigh.State { - case Unknown, Incomplete, Failed: + case Unknown, Failed: e.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) - e.notifyWakersLocked() e.dispatchAddEventLocked() + case Incomplete: + // "If an entry already exists, and the cached link-layer address + // differs from the one in the received Source Link-Layer option, the + // cached address should be replaced by the received address, and the + // entry's reachability state MUST be set to STALE." + // - RFC 4861 section 7.2.3 + e.neigh.LinkAddr = remoteLinkAddr + e.setStateLocked(Stale) + e.notifyCompletionLocked(true /* succeeded */) + e.dispatchChangeEventLocked() + case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr @@ -404,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // not be possible. SEND uses RSA key pairs to produce Cryptographically // Generated Addresses (CGA), as defined in RFC 3972. This ensures that the // claimed source of an NDP message is the owner of the claimed address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { switch e.neigh.State { case Incomplete: @@ -422,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) // "Note that the Override flag is ignored if the entry is in the // INCOMPLETE state." - RFC 4861 section 7.2.5 @@ -457,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) if !wasReachable { e.dispatchChangeEventLocked() } @@ -495,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // handleUpperLevelConfirmationLocked processes an incoming upper-level protocol // (e.g. TCP acknowledgements) reachability confirmation. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: @@ -512,23 +535,3 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() { panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } } - -// doubleLock combines two locks into one while maintaining lock ordering. -// -// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed -// neighbor is allowed. -type doubleLock struct { - first, second sync.Locker -} - -// Lock locks both locks in order: first then second. -func (l *doubleLock) Lock() { - l.first.Lock() - l.second.Lock() -} - -// Unlock unlocks both locks in reverse order: second then first. -func (l *doubleLock) Unlock() { - l.second.Unlock() - l.first.Unlock() -} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index c497d3932..ec34ffa5a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -25,7 +25,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -73,36 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option { // The following unit tests exercise every state transition and verify its // behavior with RFC 4681. // -// | From | To | Cause | Action | Event | -// | ========== | ========== | ========================================== | =============== | ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | -// | Unknown | Stale | Probe w/ unknown address | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | -// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | -// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | -// | Reachable | Stale | Reachable timer expired | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | -// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | -// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | -// | Stale | Delay | Packet queued | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | Changed | -// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | Changed | -// | Delay | Probe | Delay timer expired | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | Changed | -// | Probe | Probe | Retransmit timer expired | Send probe | Changed | -// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Failed | Failed | Packet queued | | | -// | Failed | | Unreachability timer expired | Delete entry | | +// | From | To | Cause | Update | Action | Event | +// | ========== | ========== | ========================================== | ======== | ===========| ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | +// | Unknown | Stale | Probe | | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | +// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | +// | Reachable | Stale | Reachable timer expired | | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | +// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Stale | Stale | Override confirmation | LinkAddr | | Changed | +// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | +// | Stale | Delay | Packet sent | | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | | Changed | +// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | +// | Delay | Probe | Delay timer expired | | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | +// | Probe | Probe | Retransmit timer expired | | | Changed | +// | Probe | Failed | Max probes sent without reply | | Notify | Removed | +// | Failed | Incomplete | Packet queued | | Send probe | Added | type testEntryEventType uint8 @@ -258,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -291,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -320,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -367,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -406,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() @@ -560,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) { nudDisp.mu.Unlock() } -// TestEntryAddsAndClearsWakers verifies that wakers are added when -// addWakerLocked is called and cleared when address resolution finishes. In -// this case, address resolution will finish when transitioning from Incomplete -// to Reachable. -func TestEntryAddsAndClearsWakers(t *testing.T) { +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + } e.mu.Unlock() runImmediatelyScheduledJobs(clock) @@ -593,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Lock() - if got := e.wakers; got != nil { - t.Errorf("got e.wakers = %v, want = nil", got) - } - e.addWakerLocked(&w) - if got, want := w.IsAsserted(), false; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) - } - if e.wakers == nil { - t.Error("expected e.wakers to be non-nil") - } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, - IsRouter: false, + IsRouter: true, }) - if e.wakers != nil { - t.Errorf("got e.wakers = %v, want = nil", e.wakers) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } - if got, want := w.IsAsserted(), true; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) } e.mu.Unlock() @@ -643,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { +func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -663,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - linkRes.mu.Unlock() e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, + Solicited: false, Override: false, - IsRouter: true, + IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) - } - if !e.isRouter { - t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -698,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: entryTestLinkAddr1, - State: Reachable, + State: Stale, }, }, } @@ -709,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToStale(t *testing.T) { +func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -736,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) { } e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) + e.handleProbeLocked(entryTestLinkAddr1) if e.neigh.State != Stale { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } @@ -780,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -841,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } e.mu.Unlock() } @@ -885,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.isRouter, true; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) @@ -932,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() } @@ -1083,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() } @@ -2381,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } e.mu.Unlock() @@ -2447,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.mu.Unlock() } @@ -2505,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2620,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2740,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2836,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2964,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -3101,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() @@ -3435,212 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryFailedToFailed(t *testing.T) { +func TestEntryFailedToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) - // Verify the cache contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { - t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) - } - // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in // their expected state. e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) clock.Advance(waitFor) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, { - EventType: entryTestRemoved, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, } - nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } - nudDisp.mu.Unlock() - - failedLookups := e.nic.stats.Neighbor.FailedEntryLookups - if got := failedLookups.Value(); got != 0 { - t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got) + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } e.mu.Lock() - // Verify queuing a packet to the entry immediately fails. - e.handlePacketQueuedLocked(entryTestAddr2) - state := e.neigh.State - e.mu.Unlock() - if state != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", state, Failed) - } - - if got := failedLookups.Value(); got != 1 { - t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got) - } -} - -func TestEntryFailedGetsDeleted(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - // Verify the cache contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { - t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - } - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.Advance(waitFor) - { - wantProbes := []entryTestProbeInfo{ - // The next three probe are sent in Probe. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { @@ -3653,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) { }, }, { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, + EventType: entryTestRemoved, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, { - EventType: entryTestRemoved, + EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, } @@ -3694,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - // Verify the cache no longer contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { - t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) - } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 5887aa1ed..4a34805b5 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -20,7 +20,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -54,9 +53,9 @@ type NIC struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + // packetEPs is protected by mu, but the contained packetEndpointList are + // not. + packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList } } @@ -82,6 +81,39 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } +type packetEndpointList struct { + mu sync.RWMutex + + // eps is protected by mu, but the contained PacketEndpoint values are not. + eps []PacketEndpoint +} + +func (p *packetEndpointList) add(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.eps = append(p.eps, ep) +} + +func (p *packetEndpointList) remove(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + for i, epOther := range p.eps { + if epOther == ep { + p.eps = append(p.eps[:i], p.eps[i+1:]...) + break + } + } +} + +// forEach calls fn with each endpoints in p while holding the read lock on p. +func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { + p.mu.RLock() + defer p.mu.RUnlock() + for _, ep := range p.eps { + fn(ep) + } +} + // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -102,7 +134,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -125,11 +157,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = new(packetEndpointList) } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = nil + nic.mu.packetEPs[netNum] = new(packetEndpointList) nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } @@ -172,7 +204,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -184,6 +216,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. @@ -258,15 +294,17 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // the same unresolved IP address, and transmit the saved // packet when the address has been resolved. // - // RFC 4861 section 5.2 (for IPv6): - // Once the IP address of the next-hop node is known, the sender - // examines the Neighbor Cache for link-layer information about that - // neighbor. If no entry exists, the sender creates one, sets its state - // to INCOMPLETE, initiates Address Resolution, and then queues the data - // packet pending completion of address resolution. + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and MAY + // contain more. However, the number of queued packets per neighbor SHOULD + // be limited to some small value. When a queue overflows, the new arrival + // SHOULD replace the oldest entry. Once address resolution completes, the + // node transmits any queued packets. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { - r := r.Clone() + r.Acquire() n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } @@ -279,7 +317,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // WritePacketToRemote implements NetworkInterface. func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { r := Route{ - NetProto: protocol, + routeInfo: routeInfo{ + NetProto: protocol, + }, } r.ResolveWith(remoteLinkAddr) return n.writePacket(&r, gso, protocol, pkt) @@ -508,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { return n.neigh.entries(), nil } -func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) { - if n.neigh == nil { - return - } - - n.neigh.removeWaker(addr, w) -} - func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { if n.neigh == nil { return tcpip.ErrNotSupported @@ -634,15 +666,23 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? - packetEPs := n.mu.packetEPs[protocol] - // Add any other packet type sockets that may be listening for all protocols. - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) + protoEPs := n.mu.packetEPs[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + // Deliver to interested packet endpoints without holding NIC lock. + deliverPacketEPs := func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketHost ep.HandlePacket(n.id, local, protocol, p) } + if protoEPs != nil { + protoEPs.forEach(deliverPacketEPs) + } + if anyEPs != nil { + anyEPs.forEach(deliverPacketEPs) + } // Parse headers. netProto := n.stack.NetworkProtocolInstance(protocol) @@ -683,16 +723,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + eps := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + eps.forEach(func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) - } + }) } // DeliverTransportPacket delivers the packets to the appropriate transport @@ -845,7 +886,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa if !ok { return tcpip.ErrNotSupported } - n.mu.packetEPs[netProto] = append(eps, ep) + eps.add(ep) return nil } @@ -858,13 +899,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep if !ok { return } - - for i, epOther := range eps { - if epOther == ep { - n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) - return - } - } + eps.remove(ep) } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index ab629b3a4..12d67409a 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -109,14 +109,6 @@ const ( // // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. defaultMaxReachbilityConfirmations = 3 - - // defaultUnreachableTime is the default duration for how long an entry will - // remain in the FAILED state before being removed from the neighbor cache. - // - // Note, there is no equivalent protocol constant defined in RFC 4861. It - // leaves the specifics of any garbage collection mechanism up to the - // implementation. - defaultUnreachableTime = 5 * time.Second ) // NUDDispatcher is the interface integrators of netstack must implement to @@ -278,10 +270,6 @@ type NUDConfigurations struct { // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD // configuration option is necessary. MaxReachabilityConfirmations uint32 - - // UnreachableTime describes how long an entry will remain in the FAILED - // state before being removed from the neighbor cache. - UnreachableTime time.Duration } // DefaultNUDConfigurations returns a NUDConfigurations populated with default @@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations { MaxUnicastProbes: defaultMaxUnicastProbes, MaxAnycastDelayTime: defaultMaxAnycastDelayTime, MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, - UnreachableTime: defaultUnreachableTime, } } @@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() { if c.MaxUnicastProbes == 0 { c.MaxUnicastProbes = defaultMaxUnicastProbes } - if c.UnreachableTime == 0 { - c.UnreachableTime = defaultUnreachableTime - } } // calcMaxRandomFactor calculates the maximum value of the random factor used @@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration { s.config.BaseReachableTime != s.prevBaseReachableTime || s.config.MinRandomFactor != s.prevMinRandomFactor || s.config.MaxRandomFactor != s.prevMaxRandomFactor { - return s.recomputeReachableTimeLocked() + s.recomputeReachableTimeLocked() } return s.reachableTime } @@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration { // random value gets re-computed at least once every few hours. // // s.mu MUST be locked for writing. -func (s *NUDState) recomputeReachableTimeLocked() time.Duration { +func (s *NUDState) recomputeReachableTimeLocked() { s.prevBaseReachableTime = s.config.BaseReachableTime s.prevMinRandomFactor = s.config.MinRandomFactor s.prevMaxRandomFactor = s.config.MaxRandomFactor @@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration { } s.expiration = time.Now().Add(2 * time.Hour) - return s.reachableTime } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 8cffb9fc6..7bca1373e 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -37,7 +37,6 @@ const ( defaultMaxUnicastProbes = 3 defaultMaxAnycastDelayTime = time.Second defaultMaxReachbilityConfirmations = 3 - defaultUnreachableTime = 5 * time.Second defaultFakeRandomNum = 0.5 ) @@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { } } -func TestNUDConfigurationsUnreachableTime(t *testing.T) { - tests := []struct { - name string - unreachableTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - unreachableTime: 0, - want: defaultUnreachableTime, - }, - // Valid cases - { - name: "MoreThanZero", - unreachableTime: time.Millisecond, - want: time.Millisecond, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.UnreachableTime = test.unreachableTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) - } - if got := sc.UnreachableTime; got != test.want { - t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) - } - }) - } -} - // TestNUDStateReachableTime verifies the correctness of the ReachableTime // computation. func TestNUDStateReachableTime(t *testing.T) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 5d364a2b0..4a3adcf33 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro for _, p := range packets { if cancelled { p.route.Stats().IP.OutgoingPacketErrors.Increment() - } else if _, err := p.route.Resolve(nil); err != nil { + } else if p.route.IsResolutionRequired() { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b334e27c4..7e83b7fbb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -799,19 +798,26 @@ type LinkAddressCache interface { // AddLinkAddress adds a link address to the cache. AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) - // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). - // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver - // registered with the network protocol, the cache attempts to resolve the address - // and returns ErrWouldBlock. Waker is notified when address resolution is - // complete (success or not). + // GetLinkAddress finds the link address corresponding to the remote address + // (e.g. IP -> MAC). // - // If address resolution is required, ErrNoLinkAddress and a notification channel is - // returned for the top level caller to block. Channel is closed once address resolution - // is complete (success or not). - GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) - - // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + // Returns a link address for the remote address, if readily available. + // + // Returns ErrWouldBlock if the link address is not readily available, along + // with a notification channel for the caller to block on. Triggers address + // resolution asynchronously. + // + // If onResolve is provided, it will be called either immediately, if + // resolution is not required, or when address resolution is complete, with + // the resolved link address and whether resolution succeeded. After any + // callbacks have been called, the returned notification channel is closed. + // + // If specified, the local address must be an address local to the interface + // the neighbor cache belongs to. The local address is the source address of + // a packet prompting NUD/link address resolution. + // + // TODO(gvisor.dev/issue/5151): Don't return the link address. + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index de5fe6ffe..b0251d0b4 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -31,24 +30,7 @@ import ( // // TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { - // RemoteAddress is the final destination of the route. - RemoteAddress tcpip.Address - - // LocalAddress is the local address where the route starts. - LocalAddress tcpip.Address - - // LocalLinkAddress is the link-layer (MAC) address of the - // where the route starts. - LocalLinkAddress tcpip.LinkAddress - - // NextHop is the next node in the path to the destination. - NextHop tcpip.Address - - // NetProto is the network-layer protocol. - NetProto tcpip.NetworkProtocolNumber - - // Loop controls where WritePacket should send packets. - Loop PacketLooping + routeInfo // localAddressNIC is the interface the address is associated with. // TODO(gvisor.dev/issue/4548): Remove this field once we can query the @@ -78,6 +60,45 @@ type Route struct { linkRes LinkAddressResolver } +type routeInfo struct { + // RemoteAddress is the final destination of the route. + RemoteAddress tcpip.Address + + // LocalAddress is the local address where the route starts. + LocalAddress tcpip.Address + + // LocalLinkAddress is the link-layer (MAC) address of the + // where the route starts. + LocalLinkAddress tcpip.LinkAddress + + // NextHop is the next node in the path to the destination. + NextHop tcpip.Address + + // NetProto is the network-layer protocol. + NetProto tcpip.NetworkProtocolNumber + + // Loop controls where WritePacket should send packets. + Loop PacketLooping +} + +// RouteInfo contains all of Route's exported fields. +type RouteInfo struct { + routeInfo + + // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + RemoteLinkAddress tcpip.LinkAddress +} + +// GetFields returns a RouteInfo with all of r's exported fields. This allows +// callers to store the route's fields without retaining a reference to it. +func (r *Route) GetFields() RouteInfo { + return RouteInfo{ + routeInfo: r.routeInfo, + RemoteLinkAddress: r.RemoteLinkAddress(), + } +} + // constructAndValidateRoute validates and initializes a route. It takes // ownership of the provided local address. // @@ -152,13 +173,15 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { r := &Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - localAddressNIC: localAddressNIC, - outgoingNIC: outgoingNIC, - Loop: loop, + routeInfo: routeInfo{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + Loop: loop, + }, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, } r.mu.Lock() @@ -264,22 +287,21 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in -// case address resolution requires blocking, e.g. wait for ARP reply. Waker is -// notified when address resolution is complete (success or not). +// Resolve attempts to resolve the link address if necessary. // -// If address resolution is required, ErrNoLinkAddress and a notification channel is -// returned for the top level caller to block. Channel is closed once address resolution -// is complete (success or not). -// -// The NIC r uses must not be locked. -func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { +// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. +// waiting for ARP reply). If address resolution is required, a notification +// channel is also returned for the caller to block on. The channel is closed +// once address resolution is complete (successful or not). If a callback is +// provided, it will be called when address resolution is complete, regardless +// of success or failure. +func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { r.mu.Lock() - defer r.mu.Unlock() if !r.isResolutionRequiredRLocked() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. + r.mu.Unlock() return nil, nil } @@ -288,6 +310,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { r.mu.remoteLinkAddress = r.LocalLinkAddress + r.mu.Unlock() return nil, nil } nextAddr = r.RemoteAddress @@ -300,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } + // Increment the route's reference count because finishResolution retains a + // reference to the route and releases it when called. + r.acquireLocked() + r.mu.Unlock() + + finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { + if ok { + r.ResolveWith(linkAddress) + } + if afterResolve != nil { + afterResolve() + } + r.Release() + } + if neigh := r.outgoingNIC.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) + _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) if err != nil { return ch, err } - r.mu.remoteLinkAddress = entry.LinkAddr return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) + _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution) if err != nil { return ch, err } - r.mu.remoteLinkAddress = linkAddr return nil, nil } -// RemoveWaker removes a waker that has been added in Resolve(). -func (r *Route) RemoveWaker(waker *sleep.Waker) { - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress - } - - if neigh := r.outgoingNIC.neigh; neigh != nil { - neigh.removeWaker(nextAddr, waker) - return - } - - r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) -} - // local returns true if the route is a local route. func (r *Route) local() bool { return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() @@ -419,46 +440,31 @@ func (r *Route) MTU() uint32 { return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } -// Release frees all resources associated with the route. +// Release decrements the reference counter of the resources associated with the +// route. func (r *Route) Release() { r.mu.Lock() defer r.mu.Unlock() - if r.mu.localAddressEndpoint != nil { - r.mu.localAddressEndpoint.DecRef() - r.mu.localAddressEndpoint = nil + if ep := r.mu.localAddressEndpoint; ep != nil { + ep.DecRef() } } -// Clone clones the route. -func (r *Route) Clone() *Route { +// Acquire increments the reference counter of the resources associated with the +// route. +func (r *Route) Acquire() { r.mu.RLock() defer r.mu.RUnlock() + r.acquireLocked() +} - newRoute := &Route{ - RemoteAddress: r.RemoteAddress, - LocalAddress: r.LocalAddress, - LocalLinkAddress: r.LocalLinkAddress, - NextHop: r.NextHop, - NetProto: r.NetProto, - Loop: r.Loop, - localAddressNIC: r.localAddressNIC, - outgoingNIC: r.outgoingNIC, - linkCache: r.linkCache, - linkRes: r.linkRes, - } - - newRoute.mu.Lock() - defer newRoute.mu.Unlock() - newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint - if newRoute.mu.localAddressEndpoint != nil { - if !newRoute.mu.localAddressEndpoint.IncRef() { - panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress)) +func (r *Route) acquireLocked() { + if ep := r.mu.localAddressEndpoint; ep != nil { + if !ep.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) } } - newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress - - return newRoute } // Stack returns the instance of the Stack that owns this route. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index dc4f5b3e7..114643b03 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -29,7 +29,6 @@ import ( "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -171,6 +170,9 @@ type TCPSenderState struct { // Outstanding is the number of packets in flight. Outstanding int + // SackedOut is the number of packets which have been selectively acked. + SackedOut int + // SndWnd is the send window size in bytes. SndWnd seqnum.Size @@ -1517,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() nic := s.nics[nicID] if nic == nil { @@ -1528,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve) } // Neighbors returns all IP to MAC address associations. @@ -1544,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { return nic.neighbors() } -// RemoveWaker removes a waker that has been added when link resolution for -// addr was requested. -func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { - if s.useNeighborCache { - s.mu.RLock() - nic, ok := s.nics[nicID] - s.mu.RUnlock() - - if ok { - nic.removeWaker(addr, waker) - } - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - if nic := s.nics[nicID]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.removeWaker(fullAddr, waker) - } -} - // AddStaticNeighbor statically associates an IP address to a MAC address. func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { s.mu.RLock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 457990945..856ebf6d4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1602,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = header.IPv4Any + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1656,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1666,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic2Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } @@ -1682,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -2726,8 +2738,16 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 - lifetimeSeconds = 9999 + globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") + ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") + toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + nicID = 1 + lifetimeSeconds = 9999 ) prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) @@ -2744,139 +2764,191 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix nicAddrs []tcpip.Address slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix - connectAddr tcpip.Address + remoteAddr tcpip.Address expectedLocalAddr tcpip.Address }{ - // Test Rule 1 of RFC 6724 section 5. + // Test Rule 1 of RFC 6724 section 5 (prefer same address). { name: "Same Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr1, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: globalAddr1, expectedLocalAddr: globalAddr1, }, { name: "Same Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, + remoteAddr: globalAddr1, expectedLocalAddr: globalAddr1, }, { name: "Same Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalAddr1, expectedLocalAddr: linkLocalAddr1, }, { name: "Same Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr1, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalAddr1, expectedLocalAddr: linkLocalAddr1, }, { name: "Same Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr1, + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1}, + remoteAddr: uniqueLocalAddr1, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Same Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, + remoteAddr: uniqueLocalAddr1, expectedLocalAddr: uniqueLocalAddr1, }, - // Test Rule 2 of RFC 6724 section 5. + // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope). { name: "Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: globalAddr2, expectedLocalAddr: globalAddr1, }, { name: "Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: globalAddr2, expectedLocalAddr: globalAddr1, }, { name: "Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr2, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred for link local multicast (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalMulticastAddr, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalMulticastAddr, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred for link local multicast (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalMulticastAddr, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalMulticastAddr, expectedLocalAddr: linkLocalAddr1, }, + + // Test Rule 6 of 6724 section 5 (prefer matching label). { name: "Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr2, + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1}, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1}, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, + { + name: "Toredo most preferred (first address)", + nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1}, + remoteAddr: toredoAddr2, + expectedLocalAddr: toredoAddr1, + }, + { + name: "Toredo most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1}, + remoteAddr: toredoAddr2, + expectedLocalAddr: toredoAddr1, + }, + { + name: "6To4 most preferred (first address)", + nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1}, + remoteAddr: ipv6ToIPv4Addr2, + expectedLocalAddr: ipv6ToIPv4Addr1, + }, + { + name: "6To4 most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1}, + remoteAddr: ipv6ToIPv4Addr2, + expectedLocalAddr: ipv6ToIPv4Addr1, + }, + { + name: "IPv4 mapped IPv6 most preferred (first address)", + nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1}, + remoteAddr: ipv4MappedIPv6Addr2, + expectedLocalAddr: ipv4MappedIPv6Addr1, + }, + { + name: "IPv4 mapped IPv6 most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1}, + remoteAddr: ipv4MappedIPv6Addr2, + expectedLocalAddr: ipv4MappedIPv6Addr1, + }, - // Test Rule 7 of RFC 6724 section 5. + // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses). { name: "Temp Global most preferred (last address)", slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: tempGlobalAddr1, }, { name: "Temp Global most preferred (first address)", nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: tempGlobalAddr1, }, + // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix). + { + name: "Longest prefix matched most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr2, globalAddr1}, + remoteAddr: globalAddr3, + expectedLocalAddr: globalAddr2, + }, + { + name: "Longest prefix matched most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, globalAddr2}, + remoteAddr: globalAddr3, + expectedLocalAddr: globalAddr2, + }, + // Test returning the endpoint that is closest to the front when // candidate addresses are "equal" from the perspective of RFC 6724 // section 5. { name: "Unique Local for Global", nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Link Local for Global", nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local for Unique Local", nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: uniqueLocalAddr2, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Temp Global for Global", slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, - connectAddr: globalAddr1, + remoteAddr: globalAddr1, expectedLocalAddr: tempGlobalAddr2, }, } @@ -2898,12 +2970,6 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) @@ -2923,7 +2989,23 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.FailNow() } - if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } + + addressableEndpoint, ok := netEP.(stack.AddressableEndpoint) + if !ok { + t.Fatal("network endpoint is not addressable") + } + + addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */) + if addressEP == nil { + t.Fatal("expected a non-nil address endpoint") + } + defer addressEP.DecRef() + + if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr { t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) } }) @@ -4204,8 +4286,8 @@ func TestWritePacketToRemote(t *testing.T) { if got, want := pkt.Proto, test.protocol; got != want { t.Fatalf("pkt.Proto = %d, want %d", got, want) } - if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want { - t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want) + if pkt.Route.RemoteLinkAddress != linkAddr2 { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) } if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 2cdb5ca79..737d8d912 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -141,11 +141,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -308,9 +308,8 @@ func TestBindToDeviceDistribution(t *testing.T) { defer ep.Close() ep.SocketOptions().SetReusePort(endpoint.reuse) - bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { - t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) + if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { + t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index d9769e47d..dd552b8b9 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -77,6 +77,7 @@ func (f *fakeTransportEndpoint) Abort() { } func (f *fakeTransportEndpoint) Close() { + // TODO(gvisor.dev/issue/5153): Consider retaining the route. f.route.Release() } @@ -109,8 +110,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // SetSockOpt sets a socket option. Currently not supported. @@ -146,16 +147,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return tcpip.ErrNoRoute } - defer r.Release() // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { + r.Release() return err } - f.route = r.Clone() + f.route = r return nil } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2bd472811..ef0f51f1a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -49,8 +49,9 @@ const ipv4AddressSize = 4 // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // -// Note: to support save / restore, it is important that all tcpip errors have -// distinct error messages. +// All errors must have unique msg strings. +// +// +stateify savable type Error struct { msg string @@ -257,6 +258,44 @@ func (a Address) Unspecified() bool { return true } +// MatchingPrefix returns the matching prefix length in bits. +// +// Panics if b and a have different lengths. +func (a Address) MatchingPrefix(b Address) uint8 { + const bitsInAByte = 8 + + if len(a) != len(b) { + panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b)) + } + + var prefix uint8 + for i := range a { + aByte := a[i] + bByte := b[i] + + if aByte == bByte { + prefix += bitsInAByte + continue + } + + // Count the remaining matching bits in the byte from MSbit to LSBbit. + mask := uint8(1) << (bitsInAByte - 1) + for { + if aByte&mask == bByte&mask { + prefix++ + mask >>= 1 + continue + } + + break + } + + break + } + + return prefix +} + // AddressMask is a bitmask for an address. type AddressMask string @@ -491,6 +530,17 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is + // set. + HasOriginalDstAddress bool + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress FullAddress + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr *SockError } // PacketOwner is used to get UID and GID of the packet. @@ -545,7 +595,7 @@ type Endpoint interface { // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. - Peek([][]byte) (int64, ControlMessages, *Error) + Peek([][]byte) (int64, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -905,14 +955,6 @@ type SettableSocketOption interface { isSettableSocketOption() } -// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets -// should bind only on a specific NIC. -type BindToDeviceOption NICID - -func (*BindToDeviceOption) isGettableSocketOption() {} - -func (*BindToDeviceOption) isSettableSocketOption() {} - // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. @@ -1087,14 +1129,6 @@ type RemoveMembershipOption MembershipOption func (*RemoveMembershipOption) isSettableSocketOption() {} -// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP out-of-band data is delivered along with the normal in-band data. -type OutOfBandInlineOption int - -func (*OutOfBandInlineOption) isGettableSocketOption() {} - -func (*OutOfBandInlineOption) isSettableSocketOption() {} - // SocketDetachFilterOption is used by SetSockOpt to detach a previously attached // classic BPF filter on a given endpoint. type SocketDetachFilterOption int @@ -1144,10 +1178,6 @@ type LingerOption struct { Timeout time.Duration } -func (*LingerOption) isGettableSocketOption() {} - -func (*LingerOption) isSettableSocketOption() {} - // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index c461da137..9bd563c46 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -270,3 +270,43 @@ func TestAddressUnspecified(t *testing.T) { }) } } + +func TestAddressMatchingPrefix(t *testing.T) { + tests := []struct { + addrA Address + addrB Address + prefix uint8 + }{ + { + addrA: "\x01\x01", + addrB: "\x01\x01", + prefix: 16, + }, + { + addrA: "\x01\x01", + addrB: "\x01\x00", + prefix: 15, + }, + { + addrA: "\x01\x01", + addrB: "\x81\x00", + prefix: 0, + }, + { + addrA: "\x01\x01", + addrB: "\x01\x80", + prefix: 8, + }, + { + addrA: "\x01\x01", + addrB: "\x02\x80", + prefix: 6, + }, + } + + for _, test := range tests { + if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix { + t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix) + } + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 8be791a00..2e59f6a42 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -96,11 +96,11 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -272,11 +272,11 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: uint16(payloadLen), + TransportProtocol: udp.ProtocolNumber, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 94fcd72d9..d1e4a7cb7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -75,8 +75,6 @@ type endpoint struct { route *stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -332,21 +330,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.SocketDetachFilterOption: - return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - } return nil } @@ -399,16 +388,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { @@ -524,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: r.LocalAddress, @@ -539,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } e.ID = id - e.route = r.Clone() + e.route = r e.RegisterNICID = nicID e.state = stateConnected diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 3666bac0f..e5e247342 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -85,8 +85,6 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -206,8 +204,8 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha } // Peek implements tcpip.Endpoint.Peek. -func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be @@ -306,16 +304,10 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - ep.mu.Lock() - ep.linger = *v - ep.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -374,18 +366,16 @@ func (ep *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (ep *endpoint) UpdateLastError(err *tcpip.Error) { + ep.lastErrorMu.Lock() + ep.lastError = err + ep.lastErrorMu.Unlock() +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - ep.mu.Lock() - *o = ep.linger - ep.mu.Unlock() - return nil - - default: - return tcpip.ErrNotSupported - } + return tcpip.ErrNotSupported } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 0840a4b3d..7befcfc9b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -85,8 +85,6 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route *stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -227,6 +225,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrInvalidOptionValue } + if opts.To != nil { + // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { + return 0, nil, tcpip.ErrInvalidOptionValue + } + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -256,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } e.mu.RLock() + defer e.mu.RUnlock() if e.closed { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - e.mu.RUnlock() return 0, nil, err } @@ -273,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -295,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if e.route.IsResolutionRequired() { - savedRoute := e.route - // Promote lock to exclusive if using a shared route, - // given that it may need to change in finishWrite. - e.mu.RUnlock() - e.mu.Lock() - - // Make sure that the route didn't change during the - // time we didn't hold the lock. - if !e.connected || savedRoute != e.route { - e.mu.Unlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - - n, ch, err := e.finishWrite(payloadBytes, savedRoute) - e.mu.Unlock() - return n, ch, err - } - - n, ch, err := e.finishWrite(payloadBytes, e.route) - e.mu.RUnlock() - return n, ch, err + return e.finishWrite(payloadBytes, e.route) } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } @@ -335,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - e.mu.RUnlock() return 0, nil, err } n, ch, err := e.finishWrite(payloadBytes, route) route.Release() - e.mu.RUnlock() return n, ch, err } @@ -386,8 +364,8 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Peek implements tcpip.Endpoint.Peek. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -397,6 +375,11 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { + return tcpip.ErrAddressFamilyNotSupported + } + e.mu.Lock() defer e.mu.Unlock() @@ -425,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer route.Release() if e.associated { // Re-register the endpoint with the appropriate NIC. if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + route.Release() return err } e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) @@ -437,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Save the route we've connected via. - e.route = route.Clone() + e.route = route e.connected = true return nil @@ -520,16 +503,10 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -581,16 +558,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -625,6 +593,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + e.mu.RLock() e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -637,6 +606,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // sockets. if e.rcvClosed || !e.associated { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() return @@ -644,6 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return @@ -655,11 +626,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // If bound to a NIC, only accept data for that NIC. if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() + e.mu.RUnlock() return } // If bound to an address, only accept data for that address. if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } } @@ -668,6 +641,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // connected to. if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } @@ -702,6 +676,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() e.rcvMu.Unlock() + e.mu.RUnlock() e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. if wasEmpty { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3e1041cbe..2d96a65bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() - s := sleep.Sleeper{} + var s sleep.Sleeper s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index c944dccc0..0dc710276 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error { func (h *handshake) resolveRoute() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resolutionWaker := &sleep.Waker{} s.AddWaker(resolutionWaker, wakerForResolution) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error { // Initial action is to resolve route. index := wakerForResolution + attemptedResolution := false for { switch index { case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { - if err == tcpip.ErrNoLinkAddress { - h.ep.stats.SendErrors.NoLinkAddr.Increment() - } else if err != nil { + if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock { + if err != nil { h.ep.stats.SendErrors.NoRoute.Increment() } // Either success (err == nil) or failure. return err } + if attemptedResolution { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + return tcpip.ErrNoLinkAddress + } + attemptedResolution = true // Resolution not completed. Keep trying... case wakerForNotification: n := h.ep.fetchNotifications() if n¬ifyClose != 0 { - h.ep.route.RemoveWaker(resolutionWaker) return tcpip.ErrAborted } if n¬ifyDrain != 0 { @@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error { // complete completes the TCP 3-way handshake initiated by h.start(). func (h *handshake) complete() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resendWaker := sleep.Waker{} s.AddWaker(&resendWaker, wakerForResend) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Initialize the sleeper based on the wakers in funcs. - s := sleep.Sleeper{} + var s sleep.Sleeper for i := range funcs { s.AddWaker(funcs[i].w, i) } @@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { const notification = 2 const timeWaitDone = 3 - s := sleep.Sleeper{} + var s sleep.Sleeper defer s.Done() s.AddWaker(&e.newSegmentWaker, newSegment) s.AddWaker(&e.notificationWaker, notification) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 87eda2efb..6e3c8860e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -502,9 +502,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // bindToDevice is set to the NIC on which to bind or disabled if 0. - bindToDevice tcpip.NICID - // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -674,9 +671,6 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -1040,7 +1034,8 @@ func (e *endpoint) Close() { return } - if e.linger.Enabled && e.linger.Timeout == 0 { + linger := e.SocketOptions().GetLinger() + if linger.Enabled && linger.Timeout == 0 { s := e.EndpointState() isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv if isResetState { @@ -1305,6 +1300,15 @@ func (e *endpoint) LastError() *tcpip.Error { return e.lastErrorLocked() } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.LockUser() + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + e.UnlockUser() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1498,7 +1502,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1506,10 +1510,10 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.hardErrorLocked() + return 0, e.hardErrorLocked() } e.stats.ReadErrors.InvalidEndpointState.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + return 0, tcpip.ErrInvalidEndpointState } e.rcvListMu.Lock() @@ -1518,9 +1522,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + return 0, tcpip.ErrClosedForReceive } - return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + return 0, tcpip.ErrWouldBlock } // Make a copy of vec so we can modify the slide headers. @@ -1535,7 +1539,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro for len(v) > 0 { if len(vec) == 0 { - return num, tcpip.ControlMessages{}, nil + return num, nil } if len(vec[0]) == 0 { vec = vec[1:] @@ -1550,7 +1554,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } } - return num, tcpip.ControlMessages{}, nil + return num, nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1814,18 +1818,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.LockUser() - e.bindToDevice = id - e.UnlockUser() - case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(*v) @@ -1838,9 +1837,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - case *tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(*v) @@ -1909,11 +1905,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.LockUser() - e.linger = *v - e.UnlockUser() - default: return nil } @@ -2014,11 +2005,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case *tcpip.BindToDeviceOption: - e.LockUser() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.UnlockUser() - case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.LockUser() @@ -2046,10 +2032,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - *o = 1 - case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc @@ -2078,11 +2060,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { Port: port, } - case *tcpip.LingerOption: - e.LockUser() - *o = e.linger - e.UnlockUser() - default: return tcpip.ErrUnknownProtocolOption } @@ -2230,11 +2207,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } } + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { if err != tcpip.ErrPortInUse || !reuse { return false, nil } @@ -2272,15 +2250,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { return false, nil } } id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) if err == tcpip.ErrPortInUse { return false, nil } @@ -2291,7 +2269,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // the selected port. e.ID = id e.isPortReserved = true - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags e.boundDest = addr return true, nil @@ -2302,7 +2280,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.isRegistered = true e.setEndpointState(StateConnecting) - e.route = r.Clone() + r.Acquire() + e.route = r e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -2643,7 +2622,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { id := e.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2654,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. - if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2663,7 +2643,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct. e.boundNICID = nic @@ -2727,6 +2707,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + // Linux passes the payload with the TCP header. We don't know if the TCP + // header even exists, it may not for fragmented packets. + Payload: pkt.Data.ToView(), + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.notifyProtocolGoroutine(notifyError) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { @@ -2741,16 +2756,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNoRoute - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNetworkUnreachable - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } @@ -3008,6 +3017,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { Ssthresh: e.snd.sndSsthresh, SndCAAckCount: e.snd.sndCAAckCount, Outstanding: e.snd.outstanding, + SackedOut: e.snd.sackedOut, SndWnd: e.snd.sndWnd, SndUna: e.snd.sndUna, SndNxt: e.snd.sndNxt, diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index f2b1b68da..405a6dce7 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -172,14 +172,12 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // If we started off with a window larger than what can he held in // the 16bit window field, we ceil the value to the max value. - // While ceiling, we still do not want to grow the right edge when - // not applicable. if scaledWnd > math.MaxUint16 { - if toGrow { - scaledWnd = seqnum.Size(math.MaxUint16) - } else { - scaledWnd = seqnum.Size(uint16(scaledWnd)) - } + scaledWnd = seqnum.Size(math.MaxUint16) + + // Ensure that the stashed receive window always reflects what + // is being advertised. + r.rcvWnd = scaledWnd << r.rcvWndScale } return r.rcvNxt, scaledWnd } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index baec762e1..cc991aba6 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -137,6 +137,9 @@ type sender struct { // that have been sent but not yet acknowledged. outstanding int + // sackedOut is the number of packets which are selectively acked. + sackedOut int + // sndWnd is the send window size. sndWnd seqnum.Size @@ -372,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } + oldMSS := s.maxPayloadSize s.maxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) @@ -394,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // Rewind writeNext to the first segment exceeding the MTU. Do nothing // if it is already before such a packet. + nextSeg := s.writeNext for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { if seg == s.writeNext { // We got to writeNext before we could find a segment @@ -401,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { break } - if seg.data.Size() > m { + if nextSeg == s.writeNext && seg.data.Size() > m { // We found a segment exceeding the MTU. Rewind // writeNext and try to retransmit it. - s.writeNext = seg - break + nextSeg = seg + } + + if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Update sackedOut for new maximum payload size. + s.sackedOut -= s.pCount(seg, oldMSS) + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } } // Since we likely reduced the number of outstanding packets, we may be // ready to send some more. + s.writeNext = nextSeg s.sendData() } @@ -629,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool { // pCount returns the number of packets in the segment. Due to GSO, a segment // can be composed of multiple packets. -func (s *sender) pCount(seg *segment) int { +func (s *sender) pCount(seg *segment, maxPayloadSize int) int { size := seg.data.Size() if size == 0 { return 1 } - return (size-1)/s.maxPayloadSize + 1 + return (size-1)/maxPayloadSize + 1 } // splitSeg splits a given segment at the size specified and inserts the @@ -1023,7 +1034,7 @@ func (s *sender) sendData() { break } dataSent = true - s.outstanding += s.pCount(seg) + s.outstanding += s.pCount(seg, s.maxPayloadSize) s.writeNext = seg.Next() } @@ -1038,6 +1049,7 @@ func (s *sender) enterRecovery() { // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. s.sndCwnd = s.sndSsthresh + 3 + s.sackedOut = 0 s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding @@ -1207,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.update(seg, rcvdSeg, s.ep.tsOffset) s.rc.detectReorder(seg) seg.acked = true + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } seg = seg.Next() } @@ -1380,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg) + prevCount := s.pCount(seg, s.maxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg) + s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) break } @@ -1399,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.writeList.Remove(seg) - // If SACK is enabled then Only reduce outstanding if + // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg) + s.outstanding -= s.pCount(seg, s.maxPayloadSize) + } else { + s.sackedOut -= s.pCount(seg, s.maxPayloadSize) } seg.decRef() ackLeft -= datalen diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index ef7f5719f..faf0c0ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) { expected++ } } + +// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK +// is received. +func TestSACKUpdateSackedOut(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + ackNum := 0 + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.SackedOut is what we expect. + if state.Sender.SackedOut != 2 && ackNum == 0 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) + } + + if state.Sender.SackedOut != 0 && ackNum == 1 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) + } + if ackNum > 0 { + close(probeDone) + } + ackNum++ + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + sendAndReceive(t, c, 8) + + // ACK for [3-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) + bytesRead := 2 * maxPayload + end := start.Add(seqnum.Size(bytesRead)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + bytesRead += 3 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 1759ebea9..cf60d5b53 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1380,9 +1380,8 @@ func TestConnectBindToDevice(t *testing.T) { defer c.Cleanup() c.Create(-1) - bindToDevice := tcpip.BindToDeviceOption(test.device) - if err := c.EP.SetSockOpt(&bindToDevice); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err) + if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) } // Start connection attempt. waitEntry, _ := waiter.NewChannelEntry(nil) @@ -1932,6 +1931,84 @@ func TestFullWindowReceive(t *testing.T) { ) } +// Test the stack receive window advertisement on receiving segments smaller than +// segment overhead. It tests for the right edge of the window to not grow when +// the endpoint is not being read from. +func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize, + Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), + } + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + + c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Bump up the receive buffer size such that, when the receive window grows, + // the scaled window exceeds maxUint16. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) + } + + // Keep the payload size < segment overhead and such that it is a multiple + // of the window scaled value. This enables the test to perform equality + // checks on the incoming receive window. + payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale)) + payloadLen := seqnum.Size(len(payload)) + iss := seqnum.Value(789) + seqNum := iss.Add(1) + + // Send payload to the endpoint and return the advertised receive window + // from the endpoint. + getIncomingRcvWnd := func() uint32 { + c.SendPacket(payload, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + Flags: header.TCPFlagAck, + RcvWnd: 30000, + }) + seqNum = seqNum.Add(payloadLen) + + pkt := c.GetPacket() + return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale + } + + // Read the advertised receive window with the ACK for payload. + rcvWnd := getIncomingRcvWnd() + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Read the data so that the subsequent ACK from the endpoint + // grows the right edge of the window. + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("got Read(nil) = %s", err) + } + + // Check if we have received max uint16 as our advertised + // scaled window now after a read above. + maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) + if got, want := getIncomingRcvWnd(), maxRcv; got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } +} + func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -4148,7 +4225,7 @@ func TestReadAfterClosedState(t *testing.T) { // Check that peek works. peekBuf := make([]byte, 10) - n, _, err := c.EP.Peek([][]byte{peekBuf}) + n, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { t.Fatalf("Peek failed: %s", err) } @@ -4174,7 +4251,7 @@ func TestReadAfterClosedState(t *testing.T) { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) } - if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { + if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -4429,7 +4506,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -4439,15 +4516,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) } }) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 010a23e45..ee55f030c 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -635,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - NextHeader: uint8(tcp.ProtocolNumber), - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(header.TCPMinimumSize + len(payload)), + TransportProtocol: tcp.ProtocolNumber, + HopLimit: 65, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5043e7aa5..9b9e4deb0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -30,10 +30,11 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry - senderAddress tcpip.FullAddress - packetInfo tcpip.IPPacketInfo - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + senderAddress tcpip.FullAddress + destinationAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 // tos stores either the receiveTOS or receiveTClass value. tos uint8 } @@ -108,7 +109,6 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID portFlags ports.Flags - bindToDevice tcpip.NICID lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -143,9 +143,6 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -228,6 +225,13 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -323,6 +327,10 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } + if e.ops.GetReceiveOriginalDstAddress() { + cm.HasOriginalDstAddress = true + cm.OriginalDstAddress = p.destinationAddress + } return p.data.ToView(), cm, nil } @@ -509,6 +517,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. + so := e.SocketOptions() + if so.GetRecvError() { + so.QueueLocalErr( + tcpip.ErrMessageTooLong, + route.NetProto, + header.UDPMaximumPacketSize, + tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + Port: dstPort, + }, + v, + ) + } return 0, nil, tcpip.ErrMessageTooLong } @@ -545,8 +567,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. @@ -636,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { @@ -752,22 +778,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { delete(e.multicastMemberships, memToRemove) - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.mu.Lock() - e.bindToDevice = id - e.mu.Unlock() - case *tcpip.SocketDetachFilterOption: return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() } return nil } @@ -841,16 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } e.mu.Unlock() - case *tcpip.BindToDeviceOption: - e.mu.RLock() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() - - case *tcpip.LingerOption: - e.mu.RLock() - *o = e.linger - e.mu.RUnlock() - default: return tcpip.ErrUnknownProtocolOption } @@ -1004,7 +1006,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: e.ID.LocalAddress, @@ -1032,6 +1033,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } @@ -1042,7 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.ID = id e.boundBindToDevice = btd - e.route = r.Clone() + e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID e.effectiveNetProtos = netProtos @@ -1100,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp } func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) if err != nil { - return id, e.bindToDevice, err + return id, bindToDevice, err } id.LocalPort = port } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - return id, e.bindToDevice, err + return id, bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { @@ -1311,6 +1314,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB Addr: id.RemoteAddress, Port: header.UDP(hdr).SourcePort(), }, + destinationAddress: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: header.UDP(hdr).DestinationPort(), + }, } packet.data = pkt.Data e.rcvList.PushBack(packet) @@ -1341,15 +1349,63 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + // Linux passes the payload without the UDP header. + var payload []byte + udp := header.UDP(pkt.Data.ToView()) + if len(udp) >= header.UDPMinimumSize { + payload = udp.Payload() + } + + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + Payload: payload, + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.waiterQueue.Notify(waiter.EventErr) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { if e.EndpointState() == StateConnected { - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrConnectionRefused - e.lastErrorMu.Unlock() - - e.waiterQueue.Notify(waiter.EventErr) + var errType byte + var errCode byte + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + errType = byte(header.ICMPv4DstUnreachable) + errCode = byte(header.ICMPv4PortUnreachable) + case header.IPv6ProtocolNumber: + errType = byte(header.ICMPv6DstUnreachable) + errCode = byte(header.ICMPv6PortUnreachable) + default: + panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) + } + e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt) return } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 14e4648cd..d7fc21f11 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, route.ResolveWith(r.pkt.SourceLinkAddress()) ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index e384f52dd..8429f34b4 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -452,12 +452,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. @@ -554,7 +554,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -564,15 +564,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) } }) @@ -1427,6 +1425,93 @@ func TestReadIPPacketInfo(t *testing.T) { } } +func TestReadRecvOriginalDstAddr(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedOriginalDstAddr tcpip.FullAddress + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%#v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) + } + } + + c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) + + testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1994,12 +2079,12 @@ func TestShortHeader(t *testing.T) { // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(udpSize), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 2bf0a22ff..7b5fcef9c 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -55,11 +55,8 @@ type Container struct { copyErr error cleanups []func() - // Profiles are profiles added to this container. They contain methods - // that are run after Creation, Start, and Cleanup of this Container, along - // a handle to restart the profile. Generally, tests/benchmarks using - // profiles need to run as root. - profiles []Profile + // profile is the profiling hook associated with this container. + profile *profile } // RunOpts are options for running a container. @@ -105,22 +102,7 @@ type RunOpts struct { Links []string } -// MakeContainer sets up the struct for a Docker container. -// -// Names of containers will be unique. -// Containers will check flags for profiling requests. -func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { - c := MakeNativeContainer(ctx, logger) - c.runtime = *runtime - if p := MakePprofFromFlags(c); p != nil { - c.AddProfile(p) - } - return c -} - -// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native -// containers aren't profiled. -func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { +func makeContainer(ctx context.Context, logger testutil.Logger, runtime string) *Container { // Slashes are not allowed in container names. name := testutil.RandomID(logger.Name()) name = strings.ReplaceAll(name, "/", "-") @@ -132,24 +114,29 @@ func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container return &Container{ logger: logger, Name: name, - runtime: "", + runtime: runtime, client: client, } } -// AddProfile adds a profile to this container. -func (c *Container) AddProfile(p Profile) { - c.profiles = append(c.profiles, p) +// MakeContainer constructs a suitable Container object. +// +// The runtime used is determined by the runtime flag. +// +// Containers will check flags for profiling requests. +func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { + c := makeContainer(ctx, logger, *runtime) + c.profileInit() + return c } -// RestartProfiles calls Restart on all profiles for this container. -func (c *Container) RestartProfiles() error { - for _, profile := range c.profiles { - if err := profile.Restart(c); err != nil { - return err - } - } - return nil +// MakeNativeContainer constructs a suitable Container object. +// +// The runtime used will be the system default. +// +// Native containers aren't profiled. +func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { + return makeContainer(ctx, logger, "" /*runtime*/) } // Spawn is analogous to 'docker run -d'. @@ -206,6 +193,8 @@ func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, return "", err } + c.stopProfiling() + return c.Logs(ctx) } @@ -236,11 +225,6 @@ func (c *Container) create(ctx context.Context, conf *container.Config, hostconf return err } c.id = cont.ID - for _, profile := range c.profiles { - if err := profile.OnCreate(c); err != nil { - return fmt.Errorf("OnCreate method failed with: %v", err) - } - } return nil } @@ -286,11 +270,13 @@ func (c *Container) Start(ctx context.Context) error { if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { return fmt.Errorf("ContainerStart failed: %v", err) } - for _, profile := range c.profiles { - if err := profile.OnStart(c); err != nil { - return fmt.Errorf("OnStart method failed: %v", err) + + if c.profile != nil { + if err := c.profile.Start(c); err != nil { + c.logger.Logf("profile.Start failed: %v", err) } } + return nil } @@ -499,8 +485,18 @@ func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, t } } +// stopProfiling stops profiling. +func (c *Container) stopProfiling() { + if c.profile != nil { + if err := c.profile.Stop(c); err != nil { + c.logger.Logf("profile.Stop failed: %v", err) + } + } +} + // Kill kills the container. func (c *Container) Kill(ctx context.Context) error { + c.stopProfiling() return c.client.ContainerKill(ctx, c.id, "") } @@ -517,14 +513,6 @@ func (c *Container) Remove(ctx context.Context) error { // CleanUp kills and deletes the container (best effort). func (c *Container) CleanUp(ctx context.Context) { - // Execute profile cleanups before the container goes down. - for _, profile := range c.profiles { - profile.OnCleanUp(c) - } - - // Forget profiles. - c.profiles = nil - // Execute all cleanups. We execute cleanups here to close any // open connections to the container before closing. Open connections // can cause Kill and Remove to hang. @@ -538,10 +526,12 @@ func (c *Container) CleanUp(ctx context.Context) { // Just log; can't do anything here. c.logger.Logf("error killing container %q: %v", c.Name, err) } + // Remove the image. if err := c.Remove(ctx); err != nil { c.logger.Logf("error removing container %q: %v", c.Name, err) } + // Forget all mounts. c.mounts = nil } diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 7027df1a5..a40005799 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -49,15 +49,11 @@ var ( // pprofBaseDir allows the user to change the directory to which profiles are // written. By default, profiles will appear under: // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof. - pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)") - - // duration is the max duration `runsc debug` will run and capture profiles. - // If the container's clean up method is called prior to duration, the - // profiling process will be killed. - duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds") + pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)") + pprofDuration = flag.Duration("pprof-duration", time.Hour, "profiling duration (automatically stopped at container exit)") // The below flags enable each type of profile. Multiple profiles can be - // enabled for each run. + // enabled for each run. The profile will be collected from the start. pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug") pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug") pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug") diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go index 55f9496cd..f1103eb6e 100644 --- a/pkg/test/dockerutil/profile.go +++ b/pkg/test/dockerutil/profile.go @@ -17,72 +17,57 @@ package dockerutil import ( "context" "fmt" - "io" "os" "os/exec" "path/filepath" + "syscall" "time" ) -// Profile represents profile-like operations on a container, -// such as running perf or pprof. It is meant to be added to containers -// such that the container type calls the Profile during its lifecycle. -type Profile interface { - // OnCreate is called just after the container is created when the container - // has a valid ID (e.g. c.ID()). - OnCreate(c *Container) error - - // OnStart is called just after the container is started when the container - // has a valid Pid (e.g. c.SandboxPid()). - OnStart(c *Container) error - - // Restart restarts the Profile on request. - Restart(c *Container) error - - // OnCleanUp is called during the container's cleanup method. - // Cleanups should just log errors if they have them. - OnCleanUp(c *Container) error -} - -// Pprof is for running profiles with 'runsc debug'. Pprof workloads -// should be run as root and ONLY against runsc sandboxes. The runtime -// should have --profile set as an option in /etc/docker/daemon.json in -// order for profiling to work with Pprof. -type Pprof struct { - BasePath string // path to put profiles - BlockProfile bool - CPUProfile bool - HeapProfile bool - MutexProfile bool - Duration time.Duration // duration to run profiler e.g. '10s' or '1m'. - shouldRun bool - cmd *exec.Cmd - stdout io.ReadCloser - stderr io.ReadCloser +// profile represents profile-like operations on a container. +// +// It is meant to be added to containers such that the container type calls +// the profile during its lifecycle. Standard implementations are below. + +// profile is for running profiles with 'runsc debug'. +type profile struct { + BasePath string + Types []string + Duration time.Duration + cmd *exec.Cmd } -// MakePprofFromFlags makes a Pprof profile from flags. -func MakePprofFromFlags(c *Container) *Pprof { - if !(*pprofBlock || *pprofCPU || *pprofHeap || *pprofMutex) { - return nil +// profileInit initializes a profile object, if required. +func (c *Container) profileInit() { + if !*pprofBlock && !*pprofCPU && !*pprofMutex && !*pprofHeap { + return // Nothing to do. + } + c.profile = &profile{ + BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), + Duration: *pprofDuration, } - return &Pprof{ - BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), - BlockProfile: *pprofBlock, - CPUProfile: *pprofCPU, - HeapProfile: *pprofHeap, - MutexProfile: *pprofMutex, - Duration: *duration, + if *pprofCPU { + c.profile.Types = append(c.profile.Types, "cpu") + } + if *pprofHeap { + c.profile.Types = append(c.profile.Types, "heap") + } + if *pprofMutex { + c.profile.Types = append(c.profile.Types, "mutex") + } + if *pprofBlock { + c.profile.Types = append(c.profile.Types, "block") } } -// OnCreate implements Profile.OnCreate. -func (p *Pprof) OnCreate(c *Container) error { - return os.MkdirAll(p.BasePath, 0755) -} +// createProcess creates the collection process. +func (p *profile) createProcess(c *Container) error { + // Ensure our directory exists. + if err := os.MkdirAll(p.BasePath, 0755); err != nil { + return err + } -// OnStart implements Profile.OnStart. -func (p *Pprof) OnStart(c *Container) error { + // Find the runtime to invoke. path, err := RuntimePath() if err != nil { return fmt.Errorf("failed to get runtime path: %v", err) @@ -90,58 +75,66 @@ func (p *Pprof) OnStart(c *Container) error { // The root directory of this container's runtime. root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) - // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`. + + // Format is `runsc --root=rootdir debug --profile-*=file --duration=24h containerID`. args := []string{root, "debug"} - args = append(args, p.makeProfileArgs(c)...) + for _, profileArg := range p.Types { + outputPath := filepath.Join(p.BasePath, fmt.Sprintf("%s.pprof", profileArg)) + args = append(args, fmt.Sprintf("--profile-%s=%s", profileArg, outputPath)) + } + args = append(args, fmt.Sprintf("--duration=%s", p.Duration)) // Or until container exits. args = append(args, c.ID()) // Best effort wait until container is running. for now := time.Now(); time.Since(now) < 5*time.Second; { if status, err := c.Status(context.Background()); err != nil { return fmt.Errorf("failed to get status with: %v", err) - } else if status.Running { break } - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) } p.cmd = exec.Command(path, args...) + p.cmd.Stderr = os.Stderr // Pass through errors. if err := p.cmd.Start(); err != nil { - return fmt.Errorf("process failed: %v", err) + return fmt.Errorf("start process failed: %v", err) } + return nil } -// Restart implements Profile.Restart. -func (p *Pprof) Restart(c *Container) error { - p.OnCleanUp(c) - return p.OnStart(c) +// killProcess kills the process, if running. +// +// Precondition: mu must be held. +func (p *profile) killProcess() error { + if p.cmd != nil && p.cmd.Process != nil { + return p.cmd.Process.Signal(syscall.SIGTERM) + } + return nil } -// OnCleanUp implements Profile.OnCleanup -func (p *Pprof) OnCleanUp(c *Container) error { +// waitProcess waits for the process, if running. +// +// Precondition: mu must be held. +func (p *profile) waitProcess() error { defer func() { p.cmd = nil }() - if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() { - return p.cmd.Process.Kill() + if p.cmd != nil { + return p.cmd.Wait() } return nil } -// makeProfileArgs turns Pprof fields into runsc debug flags. -func (p *Pprof) makeProfileArgs(c *Container) []string { - var ret []string - if p.BlockProfile { - ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof"))) - } - if p.CPUProfile { - ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof"))) - } - if p.HeapProfile { - ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof"))) - } - if p.MutexProfile { - ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof"))) +// Start is called when profiling is started. +func (p *profile) Start(c *Container) error { + return p.createProcess(c) +} + +// Stop is called when profiling is started. +func (p *profile) Stop(c *Container) error { + killErr := p.killProcess() + waitErr := p.waitProcess() + if waitErr != nil && killErr != nil { + return killErr } - ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration)) - return ret + return waitErr // Ignore okay wait, err kill. } diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go index 8c4ffe483..4fe9ce15c 100644 --- a/pkg/test/dockerutil/profile_test.go +++ b/pkg/test/dockerutil/profile_test.go @@ -17,6 +17,7 @@ package dockerutil import ( "context" "fmt" + "io/ioutil" "os" "path/filepath" "testing" @@ -25,52 +26,60 @@ import ( type testCase struct { name string - pprof Pprof + profile profile expectedFiles []string } -func TestPprof(t *testing.T) { +func TestProfile(t *testing.T) { // Basepath and expected file names for each type of profile. - basePath := "/tmp/test/profile" + tmpDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("unable to create temporary directory: %v", err) + } + defer os.RemoveAll(tmpDir) + + // All expected names. + basePath := tmpDir block := "block.pprof" cpu := "cpu.pprof" - goprofle := "go.pprof" heap := "heap.pprof" mutex := "mutex.pprof" testCases := []testCase{ { - name: "Cpu", - pprof: Pprof{ - BasePath: basePath, - CPUProfile: true, - Duration: 2 * time.Second, + name: "One", + profile: profile{ + BasePath: basePath, + Types: []string{"cpu"}, + Duration: 2 * time.Second, }, expectedFiles: []string{cpu}, }, { name: "All", - pprof: Pprof{ - BasePath: basePath, - BlockProfile: true, - CPUProfile: true, - HeapProfile: true, - MutexProfile: true, - Duration: 2 * time.Second, + profile: profile{ + BasePath: basePath, + Types: []string{"block", "cpu", "heap", "mutex"}, + Duration: 2 * time.Second, }, - expectedFiles: []string{block, cpu, goprofle, heap, mutex}, + expectedFiles: []string{block, cpu, heap, mutex}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() c := MakeContainer(ctx, t) + // Set basepath to include the container name so there are no conflicts. - tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name) - c.AddProfile(&tc.pprof) + localProfile := tc.profile // Copy it. + localProfile.BasePath = filepath.Join(localProfile.BasePath, tc.name) + + // Set directly on the container, to avoid flags. + c.profile = &localProfile func() { defer c.CleanUp(ctx) + // Start a container. if err := c.Spawn(ctx, RunOpts{ Image: "basic/alpine", @@ -83,24 +92,24 @@ func TestPprof(t *testing.T) { } // End early if the expected files exist and have data. - for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) { - if err := checkFiles(tc); err == nil { + for start := time.Now(); time.Since(start) < localProfile.Duration; time.Sleep(100 * time.Millisecond) { + if err := checkFiles(localProfile.BasePath, tc.expectedFiles); err == nil { break } } }() // Check all expected files exist and have data. - if err := checkFiles(tc); err != nil { + if err := checkFiles(localProfile.BasePath, tc.expectedFiles); err != nil { t.Fatalf(err.Error()) } }) } } -func checkFiles(tc testCase) error { - for _, file := range tc.expectedFiles { - stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file)) +func checkFiles(basePath string, expectedFiles []string) error { + for _, file := range expectedFiles { + stat, err := os.Stat(filepath.Join(basePath, file)) if err != nil { return fmt.Errorf("stat failed with: %v", err) } else if stat.Size() < 1 { diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go index 13b2ea314..dfd23032c 100644 --- a/pkg/urpc/urpc.go +++ b/pkg/urpc/urpc.go @@ -283,12 +283,10 @@ func (s *Server) handleOne(client *unet.Socket) error { // Client is dead. return err } + if s.afterRPCCallback != nil { + defer s.afterRPCCallback() + } - defer func() { - if s.afterRPCCallback != nil { - s.afterRPCCallback() - } - }() // Explicitly close all these files after the call. // // This is also explicitly a reference to the files after the call, diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 9b1e7a085..79db8895b 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -167,7 +167,7 @@ func (rw *IOReadWriter) Read(dst []byte) (int, error) { return n, err } -// Writer implements io.Writer.Write. +// Write implements io.Writer.Write. func (rw *IOReadWriter) Write(src []byte) (int, error) { n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts) end, ok := rw.Addr.AddLength(uint64(n)) |