summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/ioctl.go20
-rw-r--r--pkg/abi/linux/membarrier.go34
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go13
-rw-r--r--pkg/abi/linux/seccomp.go19
-rw-r--r--pkg/abi/linux/signalfd.go4
-rw-r--r--pkg/merkletree/merkletree.go224
-rw-r--r--pkg/merkletree/merkletree_test.go175
-rw-r--r--pkg/seccomp/BUILD2
-rw-r--r--pkg/seccomp/seccomp_test.go246
-rw-r--r--pkg/sentry/arch/arch_aarch64.go2
-rw-r--r--pkg/sentry/arch/registers.proto1
-rw-r--r--pkg/sentry/devices/tundev/BUILD1
-rw-r--r--pkg/sentry/devices/tundev/tundev.go14
-rw-r--r--pkg/sentry/fs/dev/BUILD1
-rw-r--r--pkg/sentry/fs/dev/net_tun.go52
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go13
-rw-r--r--pkg/sentry/fsbridge/vfs.go2
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_file.go32
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_test.go46
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_32.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_64.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/block_group_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_new.go4
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_old.go4
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/dirent_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/disklayout.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent.go12
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent_test.go9
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode.go3
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_new.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_old.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/inode_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock.go6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_32.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_64.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_old.go2
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/superblock_test.go9
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/test_utils.go6
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go5
-rw-r--r--pkg/sentry/fsimpl/ext/extent_test.go19
-rw-r--r--pkg/sentry/fsimpl/ext/utils.go8
-rw-r--r--pkg/sentry/fsimpl/host/socket.go2
-rw-r--r--pkg/sentry/fsimpl/signalfd/BUILD1
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go14
-rw-r--r--pkg/sentry/fsimpl/sys/kcov.go2
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD21
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go175
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go173
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go490
-rw-r--r--pkg/sentry/hostmm/BUILD3
-rw-r--r--pkg/sentry/hostmm/membarrier.go90
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/kcov.go40
-rw-r--r--pkg/sentry/kernel/kernel.go15
-rw-r--r--pkg/sentry/kernel/seccomp.go46
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD1
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go14
-rw-r--r--pkg/sentry/kernel/task.go2
-rw-r--r--pkg/sentry/kernel/threads.go7
-rw-r--r--pkg/sentry/kernel/vdso.go19
-rw-r--r--pkg/sentry/memmap/memmap.go4
-rw-r--r--pkg/sentry/mm/mm.go14
-rw-r--r--pkg/sentry/mm/pma.go24
-rw-r--r--pkg/sentry/mm/syscalls.go25
-rw-r--r--pkg/sentry/mm/vma.go10
-rw-r--r--pkg/sentry/platform/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/BUILD12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.s7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go31
-rw-r--r--pkg/sentry/platform/kvm/filters_amd64.go13
-rw-r--r--pkg/sentry/platform/kvm/filters_arm64.go11
-rw-r--r--pkg/sentry/platform/kvm/kvm.go13
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go9
-rw-r--r--pkg/sentry/platform/kvm/kvm_const_arm64.go21
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go17
-rw-r--r--pkg/sentry/platform/kvm/machine.go21
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go173
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go115
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go13
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go36
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go26
-rw-r--r--pkg/sentry/platform/platform.go51
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go1
-rw-r--r--pkg/sentry/platform/ring0/defs_amd64.go38
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.go7
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.s204
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s19
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD5
-rw-r--r--pkg/sentry/platform/ring0/kernel.go22
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go64
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go6
-rw-r--r--pkg/sentry/platform/ring0/lib_amd64.go12
-rw-r--r--pkg/sentry/platform/ring0/lib_amd64.s47
-rw-r--r--pkg/sentry/platform/ring0/offsets_amd64.go11
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go1
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go4
-rw-r--r--pkg/sentry/platform/ring0/x86.go40
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go4
-rw-r--r--pkg/sentry/socket/netfilter/targets.go85
-rw-r--r--pkg/sentry/socket/netstack/netstack.go33
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go1
-rw-r--r--pkg/sentry/syscalls/linux/sys_membarrier.go103
-rw-r--r--pkg/sentry/syscalls/linux/sys_sysinfo.go12
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go6
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/mount.go7
-rw-r--r--pkg/sentry/vfs/vfs.go7
-rw-r--r--pkg/tcpip/buffer/view.go18
-rw-r--r--pkg/tcpip/checker/checker.go202
-rw-r--r--pkg/tcpip/header/eth.go16
-rw-r--r--pkg/tcpip/header/eth_test.go47
-rw-r--r--pkg/tcpip/header/icmpv4.go9
-rw-r--r--pkg/tcpip/header/icmpv6.go16
-rw-r--r--pkg/tcpip/header/ipv4.go71
-rw-r--r--pkg/tcpip/header/ipv6.go14
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go113
-rw-r--r--pkg/tcpip/header/ipversion_test.go2
-rw-r--r--pkg/tcpip/link/pipe/BUILD15
-rw-r--r--pkg/tcpip/link/pipe/pipe.go124
-rw-r--r--pkg/tcpip/link/tun/device.go42
-rw-r--r--pkg/tcpip/network/arp/arp.go25
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD4
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go81
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go124
-rw-r--r--pkg/tcpip/network/ip_test.go118
-rw-r--r--pkg/tcpip/network/ipv4/BUILD3
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go146
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go221
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go577
-rw-r--r--pkg/tcpip/network/ipv6/BUILD2
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go133
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go264
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go314
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go650
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go2
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go8
-rw-r--r--pkg/tcpip/network/testutil/BUILD1
-rw-r--r--pkg/tcpip/stack/BUILD5
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go5
-rw-r--r--pkg/tcpip/stack/conntrack.go42
-rw-r--r--pkg/tcpip/stack/forwarding_test.go (renamed from pkg/tcpip/stack/forwarder_test.go)12
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go40
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go15
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go66
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go4
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go5
-rw-r--r--pkg/tcpip/stack/nic.go147
-rw-r--r--pkg/tcpip/stack/nic_test.go10
-rw-r--r--pkg/tcpip/stack/packet_buffer.go15
-rw-r--r--pkg/tcpip/stack/pending_packets.go (renamed from pkg/tcpip/stack/forwarder.go)60
-rw-r--r--pkg/tcpip/stack/registration.go41
-rw-r--r--pkg/tcpip/stack/route.go71
-rw-r--r--pkg/tcpip/stack/stack.go48
-rw-r--r--pkg/tcpip/stack/stack_test.go62
-rw-r--r--pkg/tcpip/tests/integration/BUILD4
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go378
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go219
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go2
-rw-r--r--pkg/tcpip/transport/tcp/connect.go8
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go44
-rw-r--r--pkg/tcpip/transport/tcp/rack.go54
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go54
-rw-r--r--pkg/tcpip/transport/tcp/segment.go3
-rw-r--r--pkg/tcpip/transport/tcp/snd.go62
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go75
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go20
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go12
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go30
-rw-r--r--pkg/test/testutil/testutil.go2
178 files changed, 6812 insertions, 1721 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index cdcaa8c73..4a26e28de 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -38,6 +38,7 @@ go_library(
"ipc.go",
"limits.go",
"linux.go",
+ "membarrier.go",
"mm.go",
"netdevice.go",
"netfilter.go",
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index dc9ac7e7c..7df02dd6d 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -121,9 +121,27 @@ const (
// Constants from uapi/linux/fsverity.h.
const (
- FS_IOC_ENABLE_VERITY = 1082156677
+ FS_IOC_ENABLE_VERITY = 1082156677
+ FS_IOC_MEASURE_VERITY = 3221513862
)
+// DigestMetadata is a helper struct for VerityDigest.
+//
+// +marshal
+type DigestMetadata struct {
+ DigestAlgorithm uint16
+ DigestSize uint16
+}
+
+// SizeOfDigestMetadata is the size of struct DigestMetadata.
+const SizeOfDigestMetadata = 4
+
+// VerityDigest is struct from uapi/linux/fsverity.h.
+type VerityDigest struct {
+ Metadata DigestMetadata
+ Digest []byte
+}
+
// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
func IOC(dir, typ, nr, size uint32) uint32 {
return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
diff --git a/pkg/abi/linux/membarrier.go b/pkg/abi/linux/membarrier.go
new file mode 100644
index 000000000..4f6021a1d
--- /dev/null
+++ b/pkg/abi/linux/membarrier.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// membarrier(2) commands, from include/uapi/linux/membarrier.h.
+const (
+ MEMBARRIER_CMD_QUERY = 0
+ MEMBARRIER_CMD_GLOBAL = (1 << 0)
+ MEMBARRIER_CMD_GLOBAL_EXPEDITED = (1 << 1)
+ MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED = (1 << 2)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED = (1 << 3)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED = (1 << 4)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 5)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 6)
+ MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ = (1 << 7)
+ MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ = (1 << 8)
+)
+
+// membarrier(2) flags, from include/uapi/linux/membarrier.h.
+const (
+ MEMBARRIER_CMD_FLAG_CPU = (1 << 0)
+)
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index a137940b6..6d31eb5e3 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -321,3 +321,16 @@ const (
// Enable all flags.
IP6T_INV_MASK = 0x7F
)
+
+// NFNATRange corresponds to struct nf_nat_range in
+// include/uapi/linux/netfilter/nf_nat.h.
+type NFNATRange struct {
+ Flags uint32
+ MinAddr Inet6Addr
+ MaxAddr Inet6Addr
+ MinProto uint16 // Network byte order.
+ MaxProto uint16 // Network byte order.
+}
+
+// SizeOfNFNATRange is the size of NFNATRange.
+const SizeOfNFNATRange = 40
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
index b07cafe12..5be3f10f9 100644
--- a/pkg/abi/linux/seccomp.go
+++ b/pkg/abi/linux/seccomp.go
@@ -83,3 +83,22 @@ type SockFprog struct {
pad [6]byte
Filter *BPFInstruction
}
+
+// SeccompData is equivalent to struct seccomp_data, which contains the data
+// passed to seccomp-bpf filters.
+//
+// +marshal
+type SeccompData struct {
+ // Nr is the system call number.
+ Nr int32
+
+ // Arch is an AUDIT_ARCH_* value indicating the system call convention.
+ Arch uint32
+
+ // InstructionPointer is the value of the instruction pointer at the time
+ // of the system call.
+ InstructionPointer uint64
+
+ // Args contains the first 6 system call arguments.
+ Args [6]uint64
+}
diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go
index 85fad9956..468c6a387 100644
--- a/pkg/abi/linux/signalfd.go
+++ b/pkg/abi/linux/signalfd.go
@@ -23,6 +23,8 @@ const (
)
// SignalfdSiginfo is the siginfo encoding for signalfds.
+//
+// +marshal
type SignalfdSiginfo struct {
Signo uint32
Errno int32
@@ -41,5 +43,5 @@ type SignalfdSiginfo struct {
STime uint64
Addr uint64
AddrLSB uint16
- _ [48]uint8
+ _ [48]uint8 `marshal:"unaligned"`
}
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 4b4f9bd52..a6e698f57 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -41,7 +41,7 @@ type Layout struct {
blockSize int64
// digestSize is the size of a generated hash.
digestSize int64
- // levelOffset contains the offset of the begnning of each level in
+ // levelOffset contains the offset of the beginning of each level in
// bytes. The number of levels in the tree is the length of the slice.
// The leaf nodes (level 0) contain hashes of blocks of the input data.
// Each level N contains hashes of the blocks in level N-1. The highest
@@ -123,45 +123,87 @@ func (layout Layout) blockOffset(level int, index int64) int64 {
return layout.levelOffset[level] + index*layout.blockSize
}
+// VerityDescriptor is a struct that is serialized and hashed to get a file's
+// root hash, which contains the root hash of the raw content and the file's
+// meatadata.
+type VerityDescriptor struct {
+ Name string
+ Mode uint32
+ UID uint32
+ GID uint32
+ RootHash []byte
+}
+
+func (d *VerityDescriptor) String() string {
+ return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash)
+}
+
+// GenerateParams contains the parameters used to generate a Merkle tree.
+type GenerateParams struct {
+ // File is a reader of the file to be hashed.
+ File io.ReadSeeker
+ // Size is the size of the file.
+ Size int64
+ // Name is the name of the target file.
+ Name string
+ // Mode is the mode of the target file.
+ Mode uint32
+ // UID is the user ID of the target file.
+ UID uint32
+ // GID is the group ID of the target file.
+ GID uint32
+ // TreeReader is a reader for the Merkle tree.
+ TreeReader io.ReadSeeker
+ // TreeWriter is a writer for the Merkle tree.
+ TreeWriter io.WriteSeeker
+ // DataAndTreeInSameFile is true if data and Merkle tree are in the same
+ // file, or false if Merkle tree is a separate file from data.
+ DataAndTreeInSameFile bool
+}
+
// Generate constructs a Merkle tree for the contents of data. The output is
// written to treeWriter. The treeReader should be able to read the tree after
// it has been written. That is, treeWriter and treeReader should point to the
// same underlying data but have separate cursors.
+//
+// Generate returns a hash of descriptor. The descriptor contains the file
+// metadata and the hash from file content.
+//
// Generate will modify the cursor for data, but always restores it to its
// original position upon exit. The cursor for tree is modified and not
// restored.
-func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, treeWriter io.WriteSeeker, dataAndTreeInSameFile bool) ([]byte, error) {
- layout := InitLayout(dataSize, dataAndTreeInSameFile)
+func Generate(params *GenerateParams) ([]byte, error) {
+ layout := InitLayout(params.Size, params.DataAndTreeInSameFile)
- numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize
+ numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize
// If the data is in the same file as the tree, zero pad the last data
// block.
- bytesInLastBlock := dataSize % layout.blockSize
- if dataAndTreeInSameFile && bytesInLastBlock != 0 {
+ bytesInLastBlock := params.Size % layout.blockSize
+ if params.DataAndTreeInSameFile && bytesInLastBlock != 0 {
zeroBuf := make([]byte, layout.blockSize-bytesInLastBlock)
- if _, err := treeWriter.Seek(0, io.SeekEnd); err != nil && err != io.EOF {
+ if _, err := params.TreeWriter.Seek(0, io.SeekEnd); err != nil && err != io.EOF {
return nil, err
}
- if _, err := treeWriter.Write(zeroBuf); err != nil {
+ if _, err := params.TreeWriter.Write(zeroBuf); err != nil {
return nil, err
}
}
// Store the current offset, so we can set it back once verification
// finishes.
- origOffset, err := data.Seek(0, io.SeekCurrent)
+ origOffset, err := params.File.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
- defer data.Seek(origOffset, io.SeekStart)
+ defer params.File.Seek(origOffset, io.SeekStart)
// Read from the beginning of both data and treeReader.
- if _, err := data.Seek(0, io.SeekStart); err != nil && err != io.EOF {
+ if _, err := params.File.Seek(0, io.SeekStart); err != nil && err != io.EOF {
return nil, err
}
- if _, err := treeReader.Seek(0, io.SeekStart); err != nil && err != io.EOF {
+ if _, err := params.TreeReader.Seek(0, io.SeekStart); err != nil && err != io.EOF {
return nil, err
}
@@ -176,11 +218,11 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
if level == 0 {
// Read data block from the target file since level 0 includes hashes
// of blocks in the input data.
- n, err = data.Read(buf)
+ n, err = params.File.Read(buf)
} else {
// Read data block from the tree file since levels higher than 0 are
// hashing the lower level hashes.
- n, err = treeReader.Read(buf)
+ n, err = params.TreeReader.Read(buf)
}
// err is populated as long as the bytes read is smaller than the buffer
@@ -200,7 +242,7 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
}
// Write the generated hash to the end of the tree file.
- if _, err = treeWriter.Write(digest[:]); err != nil {
+ if _, err = params.TreeWriter.Write(digest[:]); err != nil {
return nil, err
}
}
@@ -208,45 +250,124 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree
// remaining of the last block. But no need to do so for root.
if level != layout.rootLevel() && numBlocks%layout.hashesPerBlock() != 0 {
zeroBuf := make([]byte, layout.blockSize-(numBlocks%layout.hashesPerBlock())*layout.digestSize)
- if _, err := treeWriter.Write(zeroBuf[:]); err != nil {
+ if _, err := params.TreeWriter.Write(zeroBuf[:]); err != nil {
return nil, err
}
}
numBlocks = (numBlocks + layout.hashesPerBlock() - 1) / layout.hashesPerBlock()
}
- return root, nil
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ RootHash: root,
+ }
+ ret := sha256.Sum256([]byte(descriptor.String()))
+ return ret[:], nil
+}
+
+// VerifyParams contains the params used to verify a portion of a file against
+// a Merkle tree.
+type VerifyParams struct {
+ // Out will be filled with verified data.
+ Out io.Writer
+ // File is a handler on the file to be verified.
+ File io.ReadSeeker
+ // tree is a handler on the Merkle tree used to verify file.
+ Tree io.ReadSeeker
+ // Size is the size of the file.
+ Size int64
+ // Name is the name of the target file.
+ Name string
+ // Mode is the mode of the target file.
+ Mode uint32
+ // UID is the user ID of the target file.
+ UID uint32
+ // GID is the group ID of the target file.
+ GID uint32
+ // ReadOffset is the offset of the data range to be verified.
+ ReadOffset int64
+ // ReadSize is the size of the data range to be verified.
+ ReadSize int64
+ // ExpectedRoot is a trusted hash for the file. It is compared with the
+ // calculated root hash to verify the content.
+ ExpectedRoot []byte
+ // DataAndTreeInSameFile is true if data and Merkle tree are in the same
+ // file, or false if Merkle tree is a separate file from data.
+ DataAndTreeInSameFile bool
+}
+
+// verifyDescriptor generates a hash from descriptor, and compares it with
+// expectedRoot.
+func verifyDescriptor(descriptor VerityDescriptor, expectedRoot []byte) error {
+ h := sha256.Sum256([]byte(descriptor.String()))
+ if !bytes.Equal(h[:], expectedRoot) {
+ return fmt.Errorf("unexpected root hash")
+ }
+ return nil
+}
+
+// verifyMetadata verifies the metadata by hashing a descriptor that contains
+// the metadata and compare the generated hash with expectedRoot.
+//
+// For verifyMetadata, params.data is not needed. It only accesses params.tree
+// for the raw root hash.
+func verifyMetadata(params *VerifyParams, layout Layout) error {
+ if _, err := params.Tree.Seek(layout.blockOffset(layout.rootLevel(), 0 /* index */), io.SeekStart); err != nil {
+ return fmt.Errorf("failed to seek to root hash: %w", err)
+ }
+ root := make([]byte, layout.digestSize)
+ if _, err := params.Tree.Read(root); err != nil {
+ return fmt.Errorf("failed to read root hash: %w", err)
+ }
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ RootHash: root,
+ }
+ return verifyDescriptor(descriptor, params.ExpectedRoot)
}
// Verify verifies the content read from data with offset. The content is
// verified against tree. If content spans across multiple blocks, each block is
// verified. Verification fails if the hash of the data does not match the tree
// at any level, or if the final root hash does not match expectedRoot.
-// Once the data is verified, it will be written using w.
+// Once the data is verified, it will be written using params.Out.
+//
+// Verify checks for both target file content and metadata. If readSize is 0,
+// only metadata is checked.
+//
// Verify will modify the cursor for data, but always restores it to its
// original position upon exit. The cursor for tree is modified and not
// restored.
-func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) {
- if readSize <= 0 {
- return 0, fmt.Errorf("Unexpected read size: %d", readSize)
+func Verify(params *VerifyParams) (int64, error) {
+ if params.ReadSize < 0 {
+ return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize)
+ }
+ layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile)
+ if params.ReadSize == 0 {
+ return 0, verifyMetadata(params, layout)
}
- layout := InitLayout(int64(dataSize), dataAndTreeInSameFile)
// Calculate the index of blocks that includes the target range in input
// data.
- firstDataBlock := readOffset / layout.blockSize
- lastDataBlock := (readOffset + readSize - 1) / layout.blockSize
+ firstDataBlock := params.ReadOffset / layout.blockSize
+ lastDataBlock := (params.ReadOffset + params.ReadSize - 1) / layout.blockSize
// Store the current offset, so we can set it back once verification
// finishes.
- origOffset, err := data.Seek(0, io.SeekCurrent)
+ origOffset, err := params.File.Seek(0, io.SeekCurrent)
if err != nil {
- return 0, fmt.Errorf("Find current data offset failed: %v", err)
+ return 0, fmt.Errorf("find current data offset failed: %w", err)
}
- defer data.Seek(origOffset, io.SeekStart)
+ defer params.File.Seek(origOffset, io.SeekStart)
// Move to the first block that contains target data.
- if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil {
- return 0, fmt.Errorf("Seek to datablock start failed: %v", err)
+ if _, err := params.File.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil {
+ return 0, fmt.Errorf("seek to datablock start failed: %w", err)
}
buf := make([]byte, layout.blockSize)
@@ -255,7 +376,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
for i := firstDataBlock; i <= lastDataBlock; i++ {
// Read a block that includes all or part of target range in
// input data.
- bytesRead, err := data.Read(buf)
+ bytesRead, err := params.File.Read(buf)
readErr = err
// If at the end of input data and all previous blocks are
// verified, return the verified input data and EOF.
@@ -263,7 +384,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
break
}
if readErr != nil && readErr != io.EOF {
- return 0, fmt.Errorf("Read from data failed: %v", err)
+ return 0, fmt.Errorf("read from data failed: %w", err)
}
// If this is the end of file, zero the remaining bytes in buf,
// otherwise they are still from the previous block.
@@ -274,22 +395,29 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
buf[j] = 0
}
}
- if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil {
+ descriptor := VerityDescriptor{
+ Name: params.Name,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
+ }
+ if err := verifyBlock(params.Tree, descriptor, layout, buf, i, params.ExpectedRoot); err != nil {
return 0, err
}
+
// startOff is the beginning of the read range within the
// current data block. Note that for all blocks other than the
// first, startOff should be 0.
startOff := int64(0)
if i == firstDataBlock {
- startOff = readOffset % layout.blockSize
+ startOff = params.ReadOffset % layout.blockSize
}
// endOff is the end of the read range within the current data
// block. Note that for all blocks other than the last, endOff
// should be the block size.
endOff := layout.blockSize
if i == lastDataBlock {
- endOff = (readOffset+readSize-1)%layout.blockSize + 1
+ endOff = (params.ReadOffset+params.ReadSize-1)%layout.blockSize + 1
}
// If the provided size exceeds the end of input data, we should
// only copy the parts in buf that's part of input data.
@@ -299,7 +427,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
if endOff > int64(bytesRead) {
endOff = int64(bytesRead)
}
- n, err := w.Write(buf[startOff:endOff])
+ n, err := params.Out.Write(buf[startOff:endOff])
if err != nil {
return total, err
}
@@ -313,9 +441,11 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in
// original data. The block is verified through each level of the tree. It
// fails if the calculated hash from block is different from any level of
// hashes stored in tree. And the final root hash is compared with
-// expectedRoot. verifyBlock modifies the cursor for tree. Users needs to
-// maintain the cursor if intended.
-func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex int64, expectedRoot []byte) error {
+// expectedRoot.
+//
+// verifyBlock modifies the cursor for tree. Users needs to maintain the cursor
+// if intended.
+func verifyBlock(tree io.ReadSeeker, descriptor VerityDescriptor, layout Layout, dataBlock []byte, blockIndex int64, expectedRoot []byte) error {
if len(dataBlock) != int(layout.blockSize) {
return fmt.Errorf("incorrect block size")
}
@@ -352,21 +482,13 @@ func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex
}
if !bytes.Equal(digest, expectedDigest) {
- return fmt.Errorf("Verification failed")
- }
-
- // If this is the root layer, no need to generate next level
- // hash.
- if level == layout.rootLevel() {
- break
+ return fmt.Errorf("verification failed")
}
blockIndex = blockIndex / layout.hashesPerBlock()
}
- // Verification for the tree succeeded. Now compare the root hash in the
- // tree with expectedRoot.
- if !bytes.Equal(digest[:], expectedRoot) {
- return fmt.Errorf("Verification failed")
- }
- return nil
+ // Verification for the tree succeeded. Now hash the descriptor with
+ // the root hash and compare it with expectedRoot.
+ descriptor.RootHash = digest
+ return verifyDescriptor(descriptor, expectedRoot)
}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index daaca759a..bb11ec844 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -84,6 +84,13 @@ func TestLayout(t *testing.T) {
}
}
+const (
+ defaultName = "merkle_test"
+ defaultMode = 0644
+ defaultUID = 0
+ defaultGID = 0
+)
+
// bytesReadWriter is used to read from/write to/seek in a byte array. Unlike
// bytes.Buffer, it keeps the whole buffer during read so that it can be reused.
type bytesReadWriter struct {
@@ -138,19 +145,19 @@ func TestGenerate(t *testing.T) {
}{
{
data: bytes.Repeat([]byte{0}, usermem.PageSize),
- expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167},
+ expectedRoot: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
},
{
data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
- expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132},
+ expectedRoot: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
},
{
data: []byte{'a'},
- expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210},
+ expectedRoot: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
},
{
data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
- expectedRoot: []byte{201, 62, 238, 45, 13, 176, 47, 16, 172, 199, 70, 13, 149, 118, 225, 34, 220, 248, 205, 83, 196, 191, 141, 252, 174, 27, 62, 116, 235, 207, 255, 90},
+ expectedRoot: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
},
}
@@ -158,16 +165,25 @@ func TestGenerate(t *testing.T) {
t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ params := GenerateParams{
+ Size: int64(len(tc.data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
if dataAndTreeInSameFile {
tree.Write(tc.data)
- root, err = Generate(&tree, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile)
+ params.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ params.File = &bytesReadWriter{
bytes: tc.data,
- }, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile)
+ }
}
+ root, err := Generate(&params)
if err != nil {
t.Fatalf("Got err: %v, want nil", err)
}
@@ -194,6 +210,10 @@ func TestVerify(t *testing.T) {
// modified byte falls in verification range, Verify should
// fail, otherwise Verify should still succeed.
modifyByte int64
+ modifyName bool
+ modifyMode bool
+ modifyUID bool
+ modifyGID bool
shouldSucceed bool
}{
// Verify range start outside the data range should fail.
@@ -222,12 +242,48 @@ func TestVerify(t *testing.T) {
modifyByte: 0,
shouldSucceed: false,
},
- // Invalid verify range (0 size) should fail.
+ // 0 verify size should only verify metadata.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ shouldSucceed: true,
+ },
+ // Modified name should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyName: true,
+ shouldSucceed: false,
+ },
+ // Modified mode should fail verification.
{
dataSize: usermem.PageSize,
verifyStart: 0,
verifySize: 0,
modifyByte: 0,
+ modifyMode: true,
+ shouldSucceed: false,
+ },
+ // Modified UID should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyUID: true,
+ shouldSucceed: false,
+ },
+ // Modified GID should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifyGID: true,
shouldSucceed: false,
},
// The test cases below use a block-aligned verify range.
@@ -316,16 +372,25 @@ func TestVerify(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
if dataAndTreeInSameFile {
tree.Write(data)
- root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile)
+ genParams.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ genParams.File = &bytesReadWriter{
bytes: data,
- }, int64(tc.dataSize), &tree, &tree, false /* dataAndTreeInSameFile */)
+ }
}
+ root, err := Generate(&genParams)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
@@ -333,8 +398,34 @@ func TestVerify(t *testing.T) {
// Flip a bit in data and checks Verify results.
var buf bytes.Buffer
data[tc.modifyByte] ^= 1
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: tc.dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ ReadOffset: tc.verifyStart,
+ ReadSize: tc.verifySize,
+ ExpectedRoot: root,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+ if tc.modifyName {
+ verifyParams.Name = defaultName + "abc"
+ }
+ if tc.modifyMode {
+ verifyParams.Mode = defaultMode + 1
+ }
+ if tc.modifyUID {
+ verifyParams.UID = defaultUID + 1
+ }
+ if tc.modifyGID {
+ verifyParams.GID = defaultGID + 1
+ }
if tc.shouldSucceed {
- n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile)
+ n, err := Verify(&verifyParams)
if err != nil && err != io.EOF {
t.Errorf("Verification failed when expected to succeed: %v", err)
}
@@ -348,7 +439,7 @@ func TestVerify(t *testing.T) {
t.Errorf("Incorrect output buf from Verify")
}
} else {
- if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil {
+ if _, err := Verify(&verifyParams); err == nil {
t.Errorf("Verification succeeded when expected to fail")
}
}
@@ -368,16 +459,26 @@ func TestVerifyRandom(t *testing.T) {
for _, dataAndTreeInSameFile := range []bool{false, true} {
var tree bytesReadWriter
- var root []byte
- var err error
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+
if dataAndTreeInSameFile {
tree.Write(data)
- root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile)
+ genParams.File = &tree
} else {
- root, err = Generate(&bytesReadWriter{
+ genParams.File = &bytesReadWriter{
bytes: data,
- }, int64(dataSize), &tree, &tree, dataAndTreeInSameFile)
+ }
}
+ root, err := Generate(&genParams)
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
@@ -387,9 +488,24 @@ func TestVerifyRandom(t *testing.T) {
size := rand.Int63n(dataSize) + 1
var buf bytes.Buffer
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ ReadOffset: start,
+ ReadSize: size,
+ ExpectedRoot: root,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+
// Checks that the random portion of data from the original data is
// verified successfully.
- n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile)
+ n, err := Verify(&verifyParams)
if err != nil && err != io.EOF {
t.Errorf("Verification failed for correct data: %v", err)
}
@@ -406,13 +522,22 @@ func TestVerifyRandom(t *testing.T) {
t.Errorf("Incorrect output buf from Verify")
}
+ // Verify that modified metadata should fail verification.
buf.Reset()
+ verifyParams.Name = defaultName + "abc"
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verify succeeded for modified metadata, expect failure")
+ }
+
// Flip a random bit in randPortion, and check that verification fails.
+ buf.Reset()
randBytePos := rand.Int63n(size)
data[start+randBytePos] ^= 1
+ verifyParams.File = bytes.NewReader(data)
+ verifyParams.Name = defaultName
- if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil {
- t.Errorf("Verification succeeded for modified data")
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verification succeeded for modified data, expect failure")
}
}
}
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index bdef7762c..e828894b0 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -49,7 +49,7 @@ go_test(
library = ":seccomp",
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/bpf",
+ "//pkg/usermem",
],
)
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index 23f30678d..e1444d18b 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -28,17 +28,10 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-type seccompData struct {
- nr uint32
- arch uint32
- instructionPointer uint64
- args [6]uint64
-}
-
// newVictim makes a victim binary.
func newVictim() (string, error) {
f, err := ioutil.TempFile("", "victim")
@@ -58,9 +51,14 @@ func newVictim() (string, error) {
return path, nil
}
-// asInput converts a seccompData to a bpf.Input.
-func (d *seccompData) asInput() bpf.Input {
- return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian}
+// dataAsInput converts a linux.SeccompData to a bpf.Input.
+func dataAsInput(d *linux.SeccompData) bpf.Input {
+ buf := make([]byte, d.SizeBytes())
+ d.MarshalUnsafe(buf)
+ return bpf.InputBytes{
+ Data: buf,
+ Order: usermem.ByteOrder,
+ }
}
func TestBasic(t *testing.T) {
@@ -69,7 +67,7 @@ func TestBasic(t *testing.T) {
desc string
// data is the input data.
- data seccompData
+ data linux.SeccompData
// want is the expected return value of the BPF program.
want linux.BPFAction
@@ -95,12 +93,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "syscall allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "syscall disallowed",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -131,22 +129,22 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed (1a)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (1b)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "syscall 1 matched 2nd rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "no match",
- data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
@@ -168,42 +166,42 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed (1)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (3)",
- data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 3, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "allowed (5)",
- data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 5, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed (0)",
- data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (2)",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (4)",
- data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 4, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (6)",
- data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 6, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "disallowed (100)",
- data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 100, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -223,7 +221,7 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arch (123)",
- data: seccompData{nr: 1, arch: 123},
+ data: linux.SeccompData{Nr: 1, Arch: 123},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
@@ -243,7 +241,7 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "action trap",
- data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
+ data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -268,12 +266,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xe}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -300,12 +298,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "match first rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "match 2nd rule",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xe}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -331,28 +329,28 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "argument allowed (all match)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "argument disallowed (one mismatch)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "argument disallowed (multiple mismatch)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -379,28 +377,28 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (one equal)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (all equal)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
- args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32},
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
+ Args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -429,27 +427,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -474,27 +472,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (first arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -522,27 +520,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -567,32 +565,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed (both greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg allowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (second arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg smaller)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xa000ffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -620,27 +618,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -665,32 +663,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (first arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -718,27 +716,27 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "high 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits greater",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "high 32bits equal, low 32bits equal",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits equal, low 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "high 32bits less",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -764,32 +762,32 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (first arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (first arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg allowed (second arg equal)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (second arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (both arg greater)",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -816,51 +814,51 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "arg allowed (low order mandatory bit)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000001
- args: [6]uint64{0x1},
+ Args: [6]uint64{0x1},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg allowed (low order optional bit)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000101
- args: [6]uint64{0x5},
+ Args: [6]uint64{0x5},
},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "arg disallowed (lowest order bit not set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000010
- args: [6]uint64{0x2},
+ Args: [6]uint64{0x2},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (second lowest order bit set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000000 00000011
- args: [6]uint64{0x3},
+ Args: [6]uint64{0x3},
},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "arg disallowed (8th bit set)",
- data: seccompData{
- nr: 1,
- arch: LINUX_AUDIT_ARCH,
+ data: linux.SeccompData{
+ Nr: 1,
+ Arch: LINUX_AUDIT_ARCH,
// 00000000 00000000 00000001 00000000
- args: [6]uint64{0x100},
+ Args: [6]uint64{0x100},
},
want: linux.SECCOMP_RET_TRAP,
},
@@ -885,12 +883,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "allowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x7aabbccdd},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "disallowed",
- data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344},
+ data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x711223344},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -906,7 +904,7 @@ func TestBasic(t *testing.T) {
t.Fatalf("bpf.Compile() got error: %v", err)
}
for _, spec := range test.specs {
- got, err := bpf.Exec(p, spec.data.asInput())
+ got, err := bpf.Exec(p, dataAsInput(&spec.data))
if err != nil {
t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err)
}
@@ -947,8 +945,8 @@ func TestRandom(t *testing.T) {
t.Fatalf("bpf.Compile() got error: %v", err)
}
for i := uint32(0); i < 200; i++ {
- data := seccompData{nr: i, arch: LINUX_AUDIT_ARCH}
- got, err := bpf.Exec(p, data.asInput())
+ data := linux.SeccompData{Nr: int32(i), Arch: LINUX_AUDIT_ARCH}
+ got, err := bpf.Exec(p, dataAsInput(&data))
if err != nil {
t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i)
continue
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 0f433ee79..fd73751e7 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -154,6 +154,7 @@ func (s State) Proto() *rpb.Registers {
Sp: s.Regs.Sp,
Pc: s.Regs.Pc,
Pstate: s.Regs.Pstate,
+ Tls: s.Regs.TPIDR_EL0,
}
return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}}
}
@@ -232,6 +233,7 @@ func (s *State) RegisterMap() (map[string]uintptr, error) {
"Sp": uintptr(s.Regs.Sp),
"Pc": uintptr(s.Regs.Pc),
"Pstate": uintptr(s.Regs.Pstate),
+ "Tls": uintptr(s.Regs.TPIDR_EL0),
}, nil
}
diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto
index 60c027aab..2727ba08a 100644
--- a/pkg/sentry/arch/registers.proto
+++ b/pkg/sentry/arch/registers.proto
@@ -83,6 +83,7 @@ message ARM64Registers {
uint64 sp = 32;
uint64 pc = 33;
uint64 pstate = 34;
+ uint64 tls = 35;
}
message Registers {
oneof arch {
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
index 71c59287c..14a8bf9cd 100644
--- a/pkg/sentry/devices/tundev/BUILD
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -17,6 +17,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index 0b701a289..655ea549b 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -16,6 +16,8 @@
package tundev
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -26,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -84,7 +87,16 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg
return 0, err
}
flags := usermem.ByteOrder.Uint16(req.Data[:])
- return 0, fd.device.SetIff(stack.Stack, req.Name(), flags)
+ created, err := fd.device.SetIff(stack.Stack, req.Name(), flags)
+ if err == nil && created {
+ // Always start with an ARP address for interfaces so they can handle ARP
+ // packets.
+ nicID := fd.device.NICID()
+ if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID))
+ }
+ }
+ return 0, err
case linux.TUNGETIFF:
var req linux.IFReq
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 9379a4d7b..6b7b451b8 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/socket/netstack",
"//pkg/syserror",
"//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
index 5f8c9b5a2..19ffdec47 100644
--- a/pkg/sentry/fs/dev/net_tun.go
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -15,6 +15,8 @@
package dev
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -25,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -60,7 +63,7 @@ func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMod
}
// GetFile implements fs.InodeOperations.GetFile.
-func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+func (*netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
}
@@ -80,12 +83,12 @@ type netTunFileOperations struct {
var _ fs.FileOperations = (*netTunFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (fops *netTunFileOperations) Release(ctx context.Context) {
- fops.device.Release(ctx)
+func (n *netTunFileOperations) Release(ctx context.Context) {
+ n.device.Release(ctx)
}
// Ioctl implements fs.FileOperations.Ioctl.
-func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
request := args[1].Uint()
data := args[2].Pointer()
@@ -109,16 +112,25 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u
return 0, err
}
flags := usermem.ByteOrder.Uint16(req.Data[:])
- return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+ created, err := n.device.SetIff(stack.Stack, req.Name(), flags)
+ if err == nil && created {
+ // Always start with an ARP address for interfaces so they can handle ARP
+ // packets.
+ nicID := n.device.NICID()
+ if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID))
+ }
+ }
+ return 0, err
case linux.TUNGETIFF:
var req linux.IFReq
- copy(req.IFName[:], fops.device.Name())
+ copy(req.IFName[:], n.device.Name())
// Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
// there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
- flags := fops.device.Flags() | linux.IFF_NOFILTER
+ flags := n.device.Flags() | linux.IFF_NOFILTER
usermem.ByteOrder.PutUint16(req.Data[:], flags)
_, err := req.CopyOut(t, data)
@@ -130,41 +142,41 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u
}
// Write implements fs.FileOperations.Write.
-func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+func (n *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
data := make([]byte, src.NumBytes())
if _, err := src.CopyIn(ctx, data); err != nil {
return 0, err
}
- return fops.device.Write(data)
+ return n.device.Write(data)
}
// Read implements fs.FileOperations.Read.
-func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- data, err := fops.device.Read()
+func (n *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := n.device.Read()
if err != nil {
return 0, err
}
- n, err := dst.CopyOut(ctx, data)
- if n > 0 && n < len(data) {
+ bytesCopied, err := dst.CopyOut(ctx, data)
+ if bytesCopied > 0 && bytesCopied < len(data) {
// Not an error for partial copying. Packet truncated.
err = nil
}
- return int64(n), err
+ return int64(bytesCopied), err
}
// Readiness implements watier.Waitable.Readiness.
-func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
- return fops.device.Readiness(mask)
+func (n *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return n.device.Readiness(mask)
}
// EventRegister implements watier.Waitable.EventRegister.
-func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- fops.device.EventRegister(e, mask)
+func (n *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ n.device.EventRegister(e, mask)
}
// EventUnregister implements watier.Waitable.EventUnregister.
-func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
- fops.device.EventUnregister(e)
+func (n *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ n.device.EventUnregister(e)
}
// isNetTunSupported returns whether /dev/net/tun device is supported for s.
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 9eb6f522e..85a23432b 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/kernel/time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -444,7 +443,7 @@ func (c *CachingInodeOperations) TouchAccessTime(ctx context.Context, inode *fs.
// time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchAccessTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchAccessTimeLocked(now ktime.Time) {
c.attr.AccessTime = now
c.dirtyAttr.AccessTime = true
}
@@ -461,7 +460,7 @@ func (c *CachingInodeOperations) TouchModificationAndStatusChangeTime(ctx contex
// and status change times in-place to the current time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now ktime.Time) {
c.attr.ModificationTime = now
c.dirtyAttr.ModificationTime = true
c.attr.StatusChangeTime = now
@@ -480,7 +479,7 @@ func (c *CachingInodeOperations) TouchStatusChangeTime(ctx context.Context) {
// in-place to the current time.
//
// Preconditions: c.attrMu is locked for writing.
-func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now time.Time) {
+func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now ktime.Time) {
c.attr.StatusChangeTime = now
c.dirtyAttr.StatusChangeTime = true
}
@@ -672,9 +671,6 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
// Continue.
seg, gap = gap.NextSegment(), FileRangeGapIterator{}
}
-
- default:
- break
}
}
unlock()
@@ -768,9 +764,6 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
// Continue.
seg, gap = gap.NextSegment(), FileRangeGapIterator{}
-
- default:
- break
}
}
rw.maybeGrowFile()
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index 323506d33..be0900030 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -122,7 +122,7 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry)
// remainingTraversals is not configurable in VFS2, all callers are using the
// default anyways.
func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
- vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
+ vfsObj := l.root.Mount().Filesystem().VirtualFilesystem()
creds := auth.CredentialsFromContext(ctx)
path := fspath.Parse(pathname)
pop := &vfs.PathOperation{
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index abc610ef3..7b1eec3da 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/fd",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
@@ -86,9 +88,9 @@ go_test(
library = ":ext",
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/marshal/primitive",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index 8bb104ff0..1165234f9 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -18,7 +18,7 @@ import (
"io"
"math"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -34,19 +34,19 @@ type blockMapFile struct {
// directBlks are the direct blocks numbers. The physical blocks pointed by
// these holds file data. Contains file blocks 0 to 11.
- directBlks [numDirectBlks]uint32
+ directBlks [numDirectBlks]primitive.Uint32
// indirectBlk is the physical block which contains (blkSize/4) direct block
// numbers (as uint32 integers).
- indirectBlk uint32
+ indirectBlk primitive.Uint32
// doubleIndirectBlk is the physical block which contains (blkSize/4) indirect
// block numbers (as uint32 integers).
- doubleIndirectBlk uint32
+ doubleIndirectBlk primitive.Uint32
// tripleIndirectBlk is the physical block which contains (blkSize/4) doubly
// indirect block numbers (as uint32 integers).
- tripleIndirectBlk uint32
+ tripleIndirectBlk primitive.Uint32
// coverage at (i)th index indicates the amount of file data a node at
// height (i) covers. Height 0 is the direct block.
@@ -68,10 +68,12 @@ func newBlockMapFile(args inodeArgs) (*blockMapFile, error) {
}
blkMap := file.regFile.inode.diskInode.Data()
- binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
- binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
- binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
- binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk)
+ for i := 0; i < numDirectBlks; i++ {
+ file.directBlks[i].UnmarshalBytes(blkMap[i*4 : (i+1)*4])
+ }
+ file.indirectBlk.UnmarshalBytes(blkMap[numDirectBlks*4 : (numDirectBlks+1)*4])
+ file.doubleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+1)*4 : (numDirectBlks+2)*4])
+ file.tripleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+2)*4 : (numDirectBlks+3)*4])
return file, nil
}
@@ -117,16 +119,16 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) {
switch {
case offset < dirBlksEnd:
// Direct block.
- curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:])
+ curR, err = f.read(uint32(f.directBlks[offset/f.regFile.inode.blkSize]), offset%f.regFile.inode.blkSize, 0, dst[read:])
case offset < indirBlkEnd:
// Indirect block.
- curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:])
+ curR, err = f.read(uint32(f.indirectBlk), offset-dirBlksEnd, 1, dst[read:])
case offset < doubIndirBlkEnd:
// Doubly indirect block.
- curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:])
+ curR, err = f.read(uint32(f.doubleIndirectBlk), offset-indirBlkEnd, 2, dst[read:])
default:
// Triply indirect block.
- curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:])
+ curR, err = f.read(uint32(f.tripleIndirectBlk), offset-doubIndirBlkEnd, 3, dst[read:])
}
read += curR
@@ -174,13 +176,13 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds
read := 0
curChildOff := relFileOff % childCov
for i := startIdx; i < endIdx; i++ {
- var childPhyBlk uint32
+ var childPhyBlk primitive.Uint32
err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
if err != nil {
return read, err
}
- n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:])
+ n, err := f.read(uint32(childPhyBlk), curChildOff, height-1, dst[read:])
read += n
if err != nil {
return read, err
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
index 6fa84e7aa..ed98b482e 100644
--- a/pkg/sentry/fsimpl/ext/block_map_test.go
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -20,7 +20,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
@@ -87,29 +87,33 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
mockDisk := make([]byte, mockBMDiskSize)
var fileData []byte
blkNums := newBlkNumGen()
- var data []byte
+ off := 0
+ data := make([]byte, (numDirectBlks+3)*(*primitive.Uint32)(nil).SizeBytes())
// Write the direct blocks.
for i := 0; i < numDirectBlks; i++ {
- curBlkNum := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, curBlkNum)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...)
+ curBlkNum := primitive.Uint32(blkNums.next())
+ curBlkNum.MarshalBytes(data[off:])
+ off += curBlkNum.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(curBlkNum), 0, blkNums)...)
}
// Write to indirect block.
- indirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, indirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...)
-
- // Write to indirect block.
- doublyIndirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...)
-
- // Write to indirect block.
- triplyIndirectBlk := blkNums.next()
- data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
- fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
+ indirectBlk := primitive.Uint32(blkNums.next())
+ indirectBlk.MarshalBytes(data[off:])
+ off += indirectBlk.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(indirectBlk), 1, blkNums)...)
+
+ // Write to double indirect block.
+ doublyIndirectBlk := primitive.Uint32(blkNums.next())
+ doublyIndirectBlk.MarshalBytes(data[off:])
+ off += doublyIndirectBlk.SizeBytes()
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(doublyIndirectBlk), 2, blkNums)...)
+
+ // Write to triple indirect block.
+ triplyIndirectBlk := primitive.Uint32(blkNums.next())
+ triplyIndirectBlk.MarshalBytes(data[off:])
+ fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(triplyIndirectBlk), 3, blkNums)...)
args := inodeArgs{
fs: &filesystem{
@@ -142,9 +146,9 @@ func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkN
var fileData []byte
for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 {
- curBlkNum := blkNums.next()
- copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum))
- fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...)
+ curBlkNum := primitive.Uint32(blkNums.next())
+ curBlkNum.MarshalBytes(disk[off : off+4])
+ fileData = append(fileData, writeFileDataToBlock(disk, uint32(curBlkNum), height-1, blkNums)...)
}
return fileData
}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 452450d82..0ad79b381 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -16,7 +16,6 @@ package ext
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -100,7 +99,7 @@ func newDirectory(args inodeArgs, newDirent bool) (*directory, error) {
} else {
curDirent.diskDirent = &disklayout.DirentOld{}
}
- binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent)
+ curDirent.diskDirent.UnmarshalBytes(buf)
if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 {
// Inode number and name length fields being set to 0 is used to indicate
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 9bd9c76c0..d98a05dd8 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -22,10 +22,11 @@ go_library(
"superblock_old.go",
"test_utils.go",
],
+ marshal = True,
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/marshal",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
index ad6f4fef8..0d56ae9da 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
// BlockGroup represents a Linux ext block group descriptor. An ext file system
// is split into a series of block groups. This provides an access layer to
// information needed to access and use a block group.
@@ -30,6 +34,8 @@ package disklayout
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors.
type BlockGroup interface {
+ marshal.Marshallable
+
// InodeTable returns the absolute block number of the block containing the
// inode table. This points to an array of Inode structs. Inode tables are
// statically allocated at mkfs time. The superblock records the number of
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
index 3e16c76db..a35fa22a0 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go
@@ -17,6 +17,8 @@ package disklayout
// BlockGroup32Bit emulates the first half of struct ext4_group_desc in
// fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and
// 32-bit ext4 filesystems. It implements BlockGroup interface.
+//
+// +marshal
type BlockGroup32Bit struct {
BlockBitmapLo uint32
InodeBitmapLo uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
index 9a809197a..d54d1d345 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go
@@ -18,6 +18,8 @@ package disklayout
// It is the block group descriptor struct for 64-bit ext4 filesystems.
// It implements BlockGroup interface. It is an extension of the 32-bit
// version of BlockGroup.
+//
+// +marshal
type BlockGroup64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
index 0ef4294c0..e4ce484e4 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go
@@ -21,6 +21,8 @@ import (
// TestBlockGroupSize tests that the block group descriptor structs are of the
// correct size.
func TestBlockGroupSize(t *testing.T) {
- assertSize(t, BlockGroup32Bit{}, 32)
- assertSize(t, BlockGroup64Bit{}, 64)
+ var bgSmall BlockGroup32Bit
+ assertSize(t, &bgSmall, 32)
+ var bgBig BlockGroup64Bit
+ assertSize(t, &bgBig, 64)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
index 417b6cf65..568c8cb4c 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go
@@ -15,6 +15,7 @@
package disklayout
import (
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -51,6 +52,8 @@ var (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories.
type Dirent interface {
+ marshal.Marshallable
+
// Inode returns the absolute inode number of the underlying inode.
// Inode number 0 signifies an unused dirent.
Inode() uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
index 29ae4a5c2..51f9c2946 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go
@@ -29,12 +29,14 @@ import (
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
// garbage.
+//
+// +marshal
type DirentNew struct {
InodeNumber uint32
RecordLength uint16
NameLength uint8
FileTypeRaw uint8
- FileNameRaw [MaxFileName]byte
+ FileNameRaw [MaxFileName]byte `marshal:"unaligned"`
}
// Compiles only if DirentNew implements Dirent.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
index 6fff12a6e..d4b19e086 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go
@@ -22,11 +22,13 @@ import "gvisor.dev/gvisor/pkg/sentry/fs"
// Note: This struct can be of variable size on disk. The one described below
// is of maximum size and the FileName beyond NameLength bytes might contain
// garbage.
+//
+// +marshal
type DirentOld struct {
InodeNumber uint32
RecordLength uint16
NameLength uint16
- FileNameRaw [MaxFileName]byte
+ FileNameRaw [MaxFileName]byte `marshal:"unaligned"`
}
// Compiles only if DirentOld implements Dirent.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
index 934919f8a..3486864dc 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go
@@ -21,6 +21,8 @@ import (
// TestDirentSize tests that the dirent structs are of the correct
// size.
func TestDirentSize(t *testing.T) {
- assertSize(t, DirentOld{}, uintptr(DirentSize))
- assertSize(t, DirentNew{}, uintptr(DirentSize))
+ var dOld DirentOld
+ assertSize(t, &dOld, DirentSize)
+ var dNew DirentNew
+ assertSize(t, &dNew, DirentSize)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
index bdf4e2132..0834e9ba8 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go
@@ -36,8 +36,6 @@
// escape analysis on an unknown implementation at compile time.
//
// Notes:
-// - All fields in these structs are exported because binary.Read would
-// panic otherwise.
// - All structures on disk are in little-endian order. Only jbd2 (journal)
// structures are in big-endian order.
// - All OS dependent fields in these structures will be interpretted using
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 4110649ab..b13999bfc 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
// Extents were introduced in ext4 and provide huge performance gains in terms
// data locality and reduced metadata block usage. Extents are organized in
// extent trees. The root node is contained in inode.BlocksRaw.
@@ -64,6 +68,8 @@ type ExtentNode struct {
// ExtentEntry represents an extent tree node entry. The entry can either be
// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
type ExtentEntry interface {
+ marshal.Marshallable
+
// FileBlock returns the first file block number covered by this entry.
FileBlock() uint32
@@ -75,6 +81,8 @@ type ExtentEntry interface {
// tree node begins with this and is followed by `NumEntries` number of:
// - Extent if `Depth` == 0
// - ExtentIdx otherwise
+//
+// +marshal
type ExtentHeader struct {
// Magic in the extent magic number, must be 0xf30a.
Magic uint16
@@ -96,6 +104,8 @@ type ExtentHeader struct {
// internal nodes. Sorted in ascending order based on FirstFileBlock since
// Linux does a binary search on this. This points to a block containing the
// child node.
+//
+// +marshal
type ExtentIdx struct {
FirstFileBlock uint32
ChildBlockLo uint32
@@ -121,6 +131,8 @@ func (ei *ExtentIdx) PhysicalBlock() uint64 {
// nodes. Sorted in ascending order based on FirstFileBlock since Linux does a
// binary search on this. This points to an array of data blocks containing the
// file data. It covers `Length` data blocks starting from `StartBlock`.
+//
+// +marshal
type Extent struct {
FirstFileBlock uint32
Length uint16
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index 8762b90db..c96002e19 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -21,7 +21,10 @@ import (
// TestExtentSize tests that the extent structs are of the correct
// size.
func TestExtentSize(t *testing.T) {
- assertSize(t, ExtentHeader{}, ExtentHeaderSize)
- assertSize(t, ExtentIdx{}, ExtentEntrySize)
- assertSize(t, Extent{}, ExtentEntrySize)
+ var h ExtentHeader
+ assertSize(t, &h, ExtentHeaderSize)
+ var i ExtentIdx
+ assertSize(t, &i, ExtentEntrySize)
+ var e Extent
+ assertSize(t, &e, ExtentEntrySize)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go
index 88ae913f5..ef25040a9 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go
@@ -16,6 +16,7 @@ package disklayout
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
@@ -38,6 +39,8 @@ const (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes.
type Inode interface {
+ marshal.Marshallable
+
// Mode returns the linux file mode which is majorly used to extract
// information like:
// - File permissions (read/write/execute by user/group/others).
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
index 8f9f574ce..a4503f5cf 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go
@@ -27,6 +27,8 @@ import "gvisor.dev/gvisor/pkg/sentry/kernel/time"
// are used to provide nanoscond precision. Hence, these timestamps will now
// overflow in May 2446.
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps.
+//
+// +marshal
type InodeNew struct {
InodeOld
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
index db25b11b6..e6b28babf 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go
@@ -30,6 +30,8 @@ const (
//
// All fields representing time are in seconds since the epoch. Which means that
// they will overflow in January 2038.
+//
+// +marshal
type InodeOld struct {
ModeRaw uint16
UIDLo uint16
diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
index dd03ee50e..90744e956 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go
@@ -24,10 +24,12 @@ import (
// TestInodeSize tests that the inode structs are of the correct size.
func TestInodeSize(t *testing.T) {
- assertSize(t, InodeOld{}, OldInodeSize)
+ var iOld InodeOld
+ assertSize(t, &iOld, OldInodeSize)
// This was updated from 156 bytes to 160 bytes in Oct 2015.
- assertSize(t, InodeNew{}, 160)
+ var iNew InodeNew
+ assertSize(t, &iNew, 160)
}
// TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
index 8bb327006..70948ebe9 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go
@@ -14,6 +14,10 @@
package disklayout
+import (
+ "gvisor.dev/gvisor/pkg/marshal"
+)
+
const (
// SbOffset is the absolute offset at which the superblock is placed.
SbOffset = 1024
@@ -38,6 +42,8 @@ const (
//
// See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block.
type SuperBlock interface {
+ marshal.Marshallable
+
// InodesCount returns the total number of inodes in this filesystem.
InodesCount() uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
index 53e515fd3..4dc6080fb 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go
@@ -17,6 +17,8 @@ package disklayout
// SuperBlock32Bit implements SuperBlock and represents the 32-bit version of
// the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if
// RevLevel = DynamicRev and 64-bit feature is disabled.
+//
+// +marshal
type SuperBlock32Bit struct {
// We embed the old superblock struct here because the 32-bit version is just
// an extension of the old version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
index 7c1053fb4..2c9039327 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go
@@ -19,6 +19,8 @@ package disklayout
// 1024 bytes (smallest possible block size) and hence the superblock always
// fits in no more than one data block. Should only be used when the 64-bit
// feature is set.
+//
+// +marshal
type SuperBlock64Bit struct {
// We embed the 32-bit struct here because 64-bit version is just an extension
// of the 32-bit version.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
index 9221e0251..e4709f23c 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go
@@ -16,6 +16,8 @@ package disklayout
// SuperBlockOld implements SuperBlock and represents the old version of the
// superblock struct. Should be used only if RevLevel = OldRev.
+//
+// +marshal
type SuperBlockOld struct {
InodesCountRaw uint32
BlocksCountLo uint32
diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
index 463b5ba21..b734b6987 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go
@@ -21,7 +21,10 @@ import (
// TestSuperBlockSize tests that the superblock structs are of the correct
// size.
func TestSuperBlockSize(t *testing.T) {
- assertSize(t, SuperBlockOld{}, 84)
- assertSize(t, SuperBlock32Bit{}, 336)
- assertSize(t, SuperBlock64Bit{}, 1024)
+ var sbOld SuperBlockOld
+ assertSize(t, &sbOld, 84)
+ var sb32 SuperBlock32Bit
+ assertSize(t, &sb32, 336)
+ var sb64 SuperBlock64Bit
+ assertSize(t, &sb64, 1024)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
index 9c63f04c0..a4bc08411 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go
@@ -18,13 +18,13 @@ import (
"reflect"
"testing"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
)
-func assertSize(t *testing.T, v interface{}, want uintptr) {
+func assertSize(t *testing.T, v marshal.Marshallable, want int) {
t.Helper()
- if got := binary.Size(v); got != want {
+ if got := v.SizeBytes(); got != want {
t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got)
}
}
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 04917d762..778460107 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -18,7 +18,6 @@ import (
"io"
"sort"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -60,7 +59,7 @@ func newExtentFile(args inodeArgs) (*extentFile, error) {
func (f *extentFile) buildExtTree() error {
rootNodeData := f.regFile.inode.diskInode.Data()
- binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
+ f.root.Header.UnmarshalBytes(rootNodeData[:disklayout.ExtentHeaderSize])
// Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
if f.root.Header.NumEntries > 4 {
@@ -79,7 +78,7 @@ func (f *extentFile) buildExtTree() error {
// Internal node.
curEntry = &disklayout.ExtentIdx{}
}
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
+ curEntry.UnmarshalBytes(rootNodeData[off : off+disklayout.ExtentEntrySize])
f.root.Entries[i].Entry = curEntry
}
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index cd10d46ee..985f76ac0 100644
--- a/pkg/sentry/fsimpl/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -21,7 +21,6 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
)
@@ -202,13 +201,14 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []
// writeTree writes the tree represented by `root` to the inode and disk. It
// also writes random file data on disk.
func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte {
- rootData := binary.Marshal(nil, binary.LittleEndian, root.Header)
+ rootData := in.diskInode.Data()
+ root.Header.MarshalBytes(rootData)
+ off := root.Header.SizeBytes()
for _, ep := range root.Entries {
- rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry)
+ ep.Entry.MarshalBytes(rootData[off:])
+ off += ep.Entry.SizeBytes()
}
- copy(in.diskInode.Data(), rootData)
-
var fileData []byte
for _, ep := range root.Entries {
if root.Header.Height == 0 {
@@ -223,13 +223,14 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBl
// writeTreeToDisk is the recursive step for writeTree which writes the tree
// on the disk only. Also writes random file data on disk.
func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte {
- nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header)
+ nodeData := disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:]
+ curNode.Node.Header.MarshalBytes(nodeData)
+ off := curNode.Node.Header.SizeBytes()
for _, ep := range curNode.Node.Entries {
- nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry)
+ ep.Entry.MarshalBytes(nodeData[off:])
+ off += ep.Entry.SizeBytes()
}
- copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData)
-
var fileData []byte
for _, ep := range curNode.Node.Entries {
if curNode.Node.Header.Height == 0 {
diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go
index d8b728f8c..58ef7b9b8 100644
--- a/pkg/sentry/fsimpl/ext/utils.go
+++ b/pkg/sentry/fsimpl/ext/utils.go
@@ -17,21 +17,21 @@ package ext
import (
"io"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/syserror"
)
// readFromDisk performs a binary read from disk into the given struct from
// the absolute offset provided.
-func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error {
- n := binary.Size(v)
+func readFromDisk(dev io.ReaderAt, abOff int64, v marshal.Marshallable) error {
+ n := v.SizeBytes()
buf := make([]byte, n)
if read, _ := dev.ReadAt(buf, abOff); read < int(n) {
return syserror.EIO
}
- binary.Unmarshal(buf, binary.LittleEndian, v)
+ v.UnmarshalBytes(buf)
return nil
}
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 131145b85..8a447e29f 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -348,10 +348,10 @@ func (e *SCMConnectedEndpoint) Init() error {
func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
e.DecRef(func() {
e.mu.Lock()
+ fdnotifier.RemoveFD(int32(e.fd))
if err := syscall.Close(e.fd); err != nil {
log.Warningf("Failed to close host fd %d: %v", err)
}
- fdnotifier.RemoveFD(int32(e.fd))
e.destroyLocked()
e.mu.Unlock()
})
diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD
index 067c1657f..adb610213 100644
--- a/pkg/sentry/fsimpl/signalfd/BUILD
+++ b/pkg/sentry/fsimpl/signalfd/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/sentry/kernel",
"//pkg/sentry/vfs",
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index bf11b425a..10f1452ef 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -16,7 +16,6 @@ package signalfd
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -95,8 +94,7 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
}
// Copy out the signal info using the specified format.
- var buf [128]byte
- binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ infoNative := linux.SignalfdSiginfo{
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
@@ -105,9 +103,13 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
- })
- n, err := dst.CopyOut(ctx, buf[:])
- return int64(n), err
+ }
+ n, err := infoNative.WriteTo(dst.Writer(ctx))
+ if err == usermem.ErrEndOfIOSequence {
+ // Partial copy-out ok.
+ err = nil
+ }
+ return n, err
}
// Readiness implements waiter.Waitable.Readiness.
diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go
index b75d70ae6..1a6749e53 100644
--- a/pkg/sentry/fsimpl/sys/kcov.go
+++ b/pkg/sentry/fsimpl/sys/kcov.go
@@ -104,7 +104,7 @@ func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
func (fd *kcovFD) Release(ctx context.Context) {
// kcov instances have reference counts in Linux, but this seems sufficient
// for our purposes.
- fd.kcov.Reset()
+ fd.kcov.Clear()
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index bc8e38431..0ca750281 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
licenses(["notice"])
@@ -26,3 +26,22 @@ go_library(
"//pkg/usermem",
],
)
+
+go_test(
+ name = "verity_test",
+ srcs = [
+ "verity_test.go",
+ ],
+ library = ":verity",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 26b117ca4..7779271a9 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -20,6 +20,7 @@ import (
"io"
"strconv"
"strings"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -251,11 +252,35 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
Ctx: ctx,
}
+ parentStat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parent.lowerVD,
+ Start: parent.lowerVD,
+ }, &vfs.StatOptions{})
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
+ }
+ if err != nil {
+ return nil, err
+ }
+
// Since we are verifying against a directory Merkle tree, buf should
// contain the root hash of the children in the parent Merkle tree when
// Verify returns with success.
var buf bytes.Buffer
- if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF {
+ if _, err := merkletree.Verify(&merkletree.VerifyParams{
+ Out: &buf,
+ File: &fdReader,
+ Tree: &fdReader,
+ Size: int64(parentSize),
+ Name: parent.name,
+ Mode: uint32(parentStat.Mode),
+ UID: parentStat.UID,
+ GID: parentStat.GID,
+ ReadOffset: int64(offset),
+ ReadSize: int64(merkletree.DigestSize()),
+ ExpectedRoot: parent.rootHash,
+ DataAndTreeInSameFile: true,
+ }); err != nil && err != io.EOF {
return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
@@ -266,6 +291,84 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
return child, nil
}
+// verifyStat verifies the stat against the verified root hash. The mode/uid/gid
+// of the file is cached after verified.
+func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Statx) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+
+ // Get the path to the child dentry. This is only used to provide path
+ // information in failure case.
+ childPath, err := vfsObj.PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD)
+ if err != nil {
+ return err
+ }
+
+ verityMu.RLock()
+ defer verityMu.RUnlock()
+
+ fd, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: d.lowerMerkleVD,
+ Start: d.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err == syserror.ENOENT {
+ return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
+ }
+ if err != nil {
+ return err
+ }
+
+ merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{
+ Name: merkleSizeXattr,
+ Size: sizeOfStringInt32,
+ })
+
+ if err == syserror.ENODATA {
+ return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
+ }
+ if err != nil {
+ return err
+ }
+
+ size, err := strconv.Atoi(merkleSize)
+ if err != nil {
+ return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ }
+
+ fdReader := vfs.FileReadWriteSeeker{
+ FD: fd,
+ Ctx: ctx,
+ }
+
+ var buf bytes.Buffer
+ params := &merkletree.VerifyParams{
+ Out: &buf,
+ Tree: &fdReader,
+ Size: int64(size),
+ Name: d.name,
+ Mode: uint32(stat.Mode),
+ UID: stat.UID,
+ GID: stat.GID,
+ ReadOffset: 0,
+ // Set read size to 0 so only the metadata is verified.
+ ReadSize: 0,
+ ExpectedRoot: d.rootHash,
+ DataAndTreeInSameFile: false,
+ }
+ if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR {
+ params.DataAndTreeInSameFile = true
+ }
+
+ if _, err := merkletree.Verify(params); err != nil && err != io.EOF {
+ return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
+ }
+ d.mode = uint32(stat.Mode)
+ d.uid = stat.UID
+ d.gid = stat.GID
+ return nil
+}
+
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
@@ -274,9 +377,27 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
// runtime enable is allowed and the parent directory is
// enabled, we should verify the child root hash here because
// it may be cached before enabled.
- if fs.allowRuntimeEnable && len(parent.rootHash) != 0 {
- if _, err := fs.verifyChild(ctx, parent, child); err != nil {
- return nil, err
+ if fs.allowRuntimeEnable {
+ if isEnabled(parent) {
+ if _, err := fs.verifyChild(ctx, parent, child); err != nil {
+ return nil, err
+ }
+ }
+ if isEnabled(child) {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID)
+ stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: child.lowerVD,
+ Start: child.lowerVD,
+ }, &vfs.StatOptions{
+ Mask: mask,
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := fs.verifyStat(ctx, child, stat); err != nil {
+ return nil, err
+ }
}
}
return child, nil
@@ -426,7 +547,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
child.parent = parent
child.name = name
- // TODO(b/162788573): Verify child metadata.
child.mode = uint32(stat.Mode)
child.uid = stat.UID
child.gid = stat.GID
@@ -434,12 +554,18 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// Verify child root hash. This should always be performed unless in
// allowRuntimeEnable mode and the parent directory hasn't been enabled
// yet.
- if !(fs.allowRuntimeEnable && len(parent.rootHash) == 0) {
+ if isEnabled(parent) {
if _, err := fs.verifyChild(ctx, parent, child); err != nil {
child.destroyLocked(ctx)
return nil, err
}
}
+ if isEnabled(child) {
+ if err := fs.verifyStat(ctx, child, stat); err != nil {
+ child.destroyLocked(ctx)
+ return nil, err
+ }
+ }
return child, nil
}
@@ -693,22 +819,24 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// be called if a verity FD is created successfully.
defer merkleWriter.DecRef(ctx)
- parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
- Root: d.parent.lowerMerkleVD,
- Start: d.parent.lowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_WRONLY | linux.O_APPEND,
- })
- if err != nil {
- if err == syserror.ENOENT {
- parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ if d.parent != nil {
+ parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.parent.lowerMerkleVD,
+ Start: d.parent.lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_APPEND,
+ })
+ if err != nil {
+ if err == syserror.ENOENT {
+ parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
+ return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ }
+ return nil, err
}
- return nil, err
+ // parentMerkleWriter is cleaned up if any error occurs. IncRef
+ // will be called if a verity FD is created successfully.
+ defer parentMerkleWriter.DecRef(ctx)
}
- // parentMerkleWriter is cleaned up if any error occurs. IncRef
- // will be called if a verity FD is created successfully.
- defer parentMerkleWriter.DecRef(ctx)
}
fd := &fileDescription{
@@ -769,6 +897,8 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts
}
// StatAt implements vfs.FilesystemImpl.StatAt.
+// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should
+// be verified.
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
var ds *[]*dentry
fs.renameMu.RLock()
@@ -786,6 +916,11 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return linux.Statx{}, err
}
+ if isEnabled(d) {
+ if err := fs.verifyStat(ctx, d, stat); err != nil {
+ return linux.Statx{}, err
+ }
+ }
return stat, nil
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 3129f290d..3eb972237 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -142,6 +142,14 @@ func (FilesystemType) Name() string {
return Name
}
+// isEnabled checks whether the target is enabled with verity features. It
+// should always be true if runtime enable is not allowed. In runtime enable
+// mode, it returns true if the target has been enabled with
+// ioctl(FS_IOC_ENABLE_VERITY).
+func isEnabled(d *dentry) bool {
+ return !d.fs.allowRuntimeEnable || len(d.rootHash) != 0
+}
+
// alertIntegrityViolation alerts a violation of integrity, which usually means
// unexpected modification to the file system is detected. In
// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic.
@@ -245,12 +253,17 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
- // TODO(b/162788573): Verify Metadata.
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
-
d.rootHash = make([]byte, len(iopts.RootHash))
+
+ if !fs.allowRuntimeEnable {
+ if err := fs.verifyStat(ctx, d, stat); err != nil {
+ return nil, nil, err
+ }
+ }
+
copy(d.rootHash, iopts.RootHash)
d.vfsd.Init(d)
@@ -488,6 +501,11 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
if err != nil {
return linux.Statx{}, err
}
+ if isEnabled(fd.d) {
+ if err := fd.d.fs.verifyStat(ctx, fd.d, stat); err != nil {
+ return linux.Statx{}, err
+ }
+ }
return stat, nil
}
@@ -516,8 +534,10 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
FD: fd.merkleWriter,
Ctx: ctx,
}
- var rootHash []byte
- var dataSize uint64
+ params := &merkletree.GenerateParams{
+ TreeReader: &merkleReader,
+ TreeWriter: &merkleWriter,
+ }
switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT {
case linux.S_IFREG:
@@ -528,12 +548,14 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
if err != nil {
return nil, 0, err
}
- dataSize = stat.Size
- rootHash, err = merkletree.Generate(&fdReader, int64(dataSize), &merkleReader, &merkleWriter, false /* dataAndTreeInSameFile */)
- if err != nil {
- return nil, 0, err
- }
+ params.File = &fdReader
+ params.Size = int64(stat.Size)
+ params.Name = fd.d.name
+ params.Mode = uint32(stat.Mode)
+ params.UID = stat.UID
+ params.GID = stat.GID
+ params.DataAndTreeInSameFile = false
case linux.S_IFDIR:
// For a directory, generate a Merkle tree based on the root
// hashes of its children that has already been written to the
@@ -542,23 +564,32 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
if err != nil {
return nil, 0, err
}
- dataSize = merkleStat.Size
- rootHash, err = merkletree.Generate(&merkleReader, int64(dataSize), &merkleReader, &merkleWriter, true /* dataAndTreeInSameFile */)
+ params.Size = int64(merkleStat.Size)
+
+ stat, err := fd.lowerFD.Stat(ctx, vfs.StatOptions{})
if err != nil {
return nil, 0, err
}
+
+ params.File = &merkleReader
+ params.Name = fd.d.name
+ params.Mode = uint32(stat.Mode)
+ params.UID = stat.UID
+ params.GID = stat.GID
+ params.DataAndTreeInSameFile = true
default:
// TODO(b/167728857): Investigate whether and how we should
// enable other types of file.
return nil, 0, syserror.EINVAL
}
- return rootHash, dataSize, nil
+ rootHash, err := merkletree.Generate(params)
+ return rootHash, uint64(params.Size), err
}
// enableVerity enables verity features on fd by generating a Merkle tree file
// and stores its root hash in its parent directory's Merkle tree.
-func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) {
if !fd.d.fs.allowRuntimeEnable {
return 0, syserror.EPERM
}
@@ -568,7 +599,11 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg
verityMu.Lock()
defer verityMu.Unlock()
- if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || fd.parentMerkleWriter == nil {
+ // In allowRuntimeEnable mode, the underlying fd and read/write fd for
+ // the Merkle tree file should have all been initialized. For any file
+ // or directory other than the root, the parent Merkle tree file should
+ // have also been initialized.
+ if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds")
}
@@ -577,26 +612,28 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg
return 0, err
}
- stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- return 0, err
- }
+ if fd.parentMerkleWriter != nil {
+ stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return 0, err
+ }
- // Write the root hash of fd to the parent directory's Merkle tree
- // file, as it should be part of the parent Merkle tree data.
- // parentMerkleWriter is open with O_APPEND, so it should write
- // directly to the end of the file.
- if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil {
- return 0, err
- }
+ // Write the root hash of fd to the parent directory's Merkle
+ // tree file, as it should be part of the parent Merkle tree
+ // data. parentMerkleWriter is open with O_APPEND, so it
+ // should write directly to the end of the file.
+ if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil {
+ return 0, err
+ }
- // Record the offset of the root hash of fd in parent directory's
- // Merkle tree file.
- if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{
- Name: merkleOffsetInParentXattr,
- Value: strconv.Itoa(int(stat.Size)),
- }); err != nil {
- return 0, err
+ // Record the offset of the root hash of fd in parent directory's
+ // Merkle tree file.
+ if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{
+ Name: merkleOffsetInParentXattr,
+ Value: strconv.Itoa(int(stat.Size)),
+ }); err != nil {
+ return 0, err
+ }
}
// Record the size of the data being hashed for fd.
@@ -610,7 +647,45 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg
return 0, nil
}
-func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+// measureVerity returns the root hash of fd, saved in args[2].
+func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ var metadata linux.DigestMetadata
+
+ // If allowRuntimeEnable is true, an empty fd.d.rootHash indicates that
+ // verity is not enabled for the file. If allowRuntimeEnable is false,
+ // this is an integrity violation because all files should have verity
+ // enabled, in which case fd.d.rootHash should be set.
+ if len(fd.d.rootHash) == 0 {
+ if fd.d.fs.allowRuntimeEnable {
+ return 0, syserror.ENODATA
+ }
+ return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no root hash found")
+ }
+
+ // The first part of VerityDigest is the metadata.
+ if _, err := metadata.CopyIn(t, verityDigest); err != nil {
+ return 0, err
+ }
+ if metadata.DigestSize < uint16(len(fd.d.rootHash)) {
+ return 0, syserror.EOVERFLOW
+ }
+
+ // Populate the output digest size, since DigestSize is both input and
+ // output.
+ metadata.DigestSize = uint16(len(fd.d.rootHash))
+
+ // First copy the metadata.
+ if _, err := metadata.CopyOut(t, verityDigest); err != nil {
+ return 0, err
+ }
+
+ // Now copy the root hash bytes to the memory after metadata.
+ _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.rootHash)
+ return 0, err
+}
+
+func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) {
f := int32(0)
// All enabled files should store a root hash. This flag is not settable
@@ -620,8 +695,7 @@ func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args ar
}
t := kernel.TaskFromContext(ctx)
- addr := args[2].Pointer()
- _, err := primitive.CopyInt32Out(t, addr, f)
+ _, err := primitive.CopyInt32Out(t, flags, f)
return 0, err
}
@@ -629,11 +703,15 @@ func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args ar
func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := args[1].Uint(); cmd {
case linux.FS_IOC_ENABLE_VERITY:
- return fd.enableVerity(ctx, uio, args)
+ return fd.enableVerity(ctx, uio)
+ case linux.FS_IOC_MEASURE_VERITY:
+ return fd.measureVerity(ctx, uio, args[2].Pointer())
case linux.FS_IOC_GETFLAGS:
- return fd.getFlags(ctx, uio, args)
+ return fd.verityFlags(ctx, uio, args[2].Pointer())
default:
- return fd.lowerFD.Ioctl(ctx, uio, args)
+ // TODO(b/169682228): Investigate which ioctl commands should
+ // be allowed.
+ return 0, syserror.ENOSYS
}
}
@@ -641,7 +719,7 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
// No need to verify if the file is not enabled yet in
// allowRuntimeEnable mode.
- if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 {
+ if !isEnabled(fd.d) {
return fd.lowerFD.PRead(ctx, dst, offset, opts)
}
@@ -678,9 +756,22 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
Ctx: ctx,
}
- n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */)
+ n, err := merkletree.Verify(&merkletree.VerifyParams{
+ Out: dst.Writer(ctx),
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ ReadOffset: offset,
+ ReadSize: dst.NumBytes(),
+ ExpectedRoot: fd.d.rootHash,
+ DataAndTreeInSameFile: false,
+ })
if err != nil {
- return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err))
+ return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err))
}
return n, err
}
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
new file mode 100644
index 000000000..8d0926bc4
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -0,0 +1,490 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package verity
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// rootMerkleFilename is the name of the root Merkle tree file.
+const rootMerkleFilename = "root.verity"
+
+// maxDataSize is the maximum data size written to the file for test.
+const maxDataSize = 100000
+
+// newVerityRoot creates a new verity mount, and returns the root. The
+// underlying file system is tmpfs. If the error is not nil, then cleanup
+// should be called when the root is no longer needed.
+func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) {
+ rand.Seed(time.Now().UnixNano())
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err)
+ }
+
+ vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: InternalFilesystemOptions{
+ RootMerkleFileName: rootMerkleFilename,
+ LowerName: "tmpfs",
+ AllowRuntimeEnable: true,
+ NoCrashOnVerificationFailure: true,
+ },
+ },
+ })
+ if err != nil {
+ return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err)
+ }
+ root := mntns.Root()
+ t.Helper()
+ t.Cleanup(func() {
+ root.DecRef(ctx)
+ mntns.DecRef(ctx)
+ })
+ return vfsObj, root, nil
+}
+
+// newFileFD creates a new file in the verity mount, and returns the FD. The FD
+// points to a file that has random data generated.
+func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+
+ // Create the file in the underlying file system.
+ lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(filePath),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.ModeRegular | mode,
+ })
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Generate random data to be written to the file.
+ dataSize := rand.Intn(maxDataSize) + 1
+ data := make([]byte, dataSize)
+ rand.Read(data)
+
+ // Write directly to the underlying FD, since verity FD is read-only.
+ n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ return nil, 0, err
+ }
+
+ if n != int64(len(data)) {
+ return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data))
+ }
+
+ lowerFD.DecRef(ctx)
+
+ // Now open the verity file descriptor.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filePath),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular | mode,
+ })
+ return fd, dataSize, err
+}
+
+// corruptRandomBit randomly flips a bit in the file represented by fd.
+func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
+ // Flip a random bit in the underlying file.
+ randomPos := int64(rand.Intn(size))
+ byteToModify := make([]byte, 1)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil {
+ return fmt.Errorf("lowerFD.PRead: %v", err)
+ }
+ byteToModify[0] ^= 1
+ if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil {
+ return fmt.Errorf("lowerFD.PWrite: %v", err)
+ }
+ return nil
+}
+
+// TestOpen ensures that when a file is created, the corresponding Merkle tree
+// file and the root Merkle tree file exist.
+func TestOpen(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Ensure that the corresponding Merkle tree file is created.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
+ }
+
+ // Ensure the root merkle tree file is created.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + rootMerkleFilename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+ }
+}
+
+// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file
+// succeeds after enabling verity for it.
+func TestReadUnmodifiedFileSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
+}
+
+// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
+// succeeds after enabling verity for it.
+func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirms a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Ensure reopening the verity enabled file succeeds.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != nil {
+ t.Errorf("reopen enabled file failed: %v", err)
+ }
+}
+
+// TestModifiedFileFails ensures that read from a modified verity file fails.
+func TestModifiedFileFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.PRead succeeded with modified file")
+ }
+}
+
+// TestModifiedMerkleFails ensures that read from a verity file fails if the
+// corresponding Merkle tree file is modified.
+func TestModifiedMerkleFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
+
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the Merkle tree file.
+ stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ merkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from a file with modified Merkle tree fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ fmt.Println(buf)
+ t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ }
+}
+
+// TestModifiedParentMerkleFails ensures that open a verity enabled file in a
+// verity enabled directory fails if the hashes related to the target file in
+// the parent Merkle tree file is modified.
+func TestModifiedParentMerkleFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Enable verity on the parent directory.
+ parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
+
+ parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: parentLowerMerkleVD,
+ Start: parentLowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the parent Merkle tree file.
+ // This parent directory contains only one child, so any random
+ // modification in the parent Merkle tree should cause verification
+ // failure when opening the child file.
+ stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ parentMerkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ parentLowerMerkleFD.DecRef(ctx)
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err == nil {
+ t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ }
+}
+
+// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file
+// succeeds after enabling verity for it.
+func TestUnmodifiedStatSucceeds(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirms stat succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
+ t.Errorf("fd.Stat: %v", err)
+ }
+}
+
+// TestModifiedStatFails checks that getting stat for a file with modified stat
+// should fail.
+func TestModifiedStatFails(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj, root, err := newVerityRoot(ctx, t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ lowerFD := fd.Impl().(*fileDescription).lowerFD
+ // Change the stat of the underlying file, and check that stat fails.
+ if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: uint32(linux.STATX_MODE),
+ Mode: 0777,
+ },
+ }); err != nil {
+ t.Fatalf("lowerFD.SetStat: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
+ t.Errorf("fd.Stat succeeded when it should fail")
+ }
+}
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
index 61c78569d..300b7ccce 100644
--- a/pkg/sentry/hostmm/BUILD
+++ b/pkg/sentry/hostmm/BUILD
@@ -7,11 +7,14 @@ go_library(
srcs = [
"cgroup.go",
"hostmm.go",
+ "membarrier.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/abi/linux",
"//pkg/fd",
"//pkg/log",
"//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/hostmm/membarrier.go b/pkg/sentry/hostmm/membarrier.go
new file mode 100644
index 000000000..4468d75f1
--- /dev/null
+++ b/pkg/sentry/hostmm/membarrier.go
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostmm
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+var (
+ haveMembarrierGlobal = false
+ haveMembarrierPrivateExpedited = false
+)
+
+func init() {
+ supported, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_QUERY, 0 /* flags */, 0 /* unused */)
+ if e != 0 {
+ if e != syscall.ENOSYS {
+ log.Warningf("membarrier(MEMBARRIER_CMD_QUERY) failed: %s", e.Error())
+ }
+ return
+ }
+ // We don't use MEMBARRIER_CMD_GLOBAL_EXPEDITED because this sends IPIs to
+ // all CPUs running tasks that have previously invoked
+ // MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED, which presents a DOS risk.
+ // (MEMBARRIER_CMD_GLOBAL is synchronize_rcu(), i.e. it waits for an RCU
+ // grace period to elapse without bothering other CPUs.
+ // MEMBARRIER_CMD_PRIVATE_EXPEDITED sends IPIs only to CPUs running tasks
+ // sharing the caller's MM.)
+ if supported&linux.MEMBARRIER_CMD_GLOBAL != 0 {
+ haveMembarrierGlobal = true
+ }
+ if req := uintptr(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED | linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED); supported&req == req {
+ if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 {
+ log.Warningf("membarrier(MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED) failed: %s", e.Error())
+ } else {
+ haveMembarrierPrivateExpedited = true
+ }
+ }
+}
+
+// HaveGlobalMemoryBarrier returns true if GlobalMemoryBarrier is supported.
+func HaveGlobalMemoryBarrier() bool {
+ return haveMembarrierGlobal
+}
+
+// GlobalMemoryBarrier blocks until "all running threads [in the host OS] have
+// passed through a state where all memory accesses to user-space addresses
+// match program order between entry to and return from [GlobalMemoryBarrier]",
+// as for membarrier(2).
+//
+// Preconditions: HaveGlobalMemoryBarrier() == true.
+func GlobalMemoryBarrier() error {
+ if _, _, e := syscall.Syscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_GLOBAL, 0 /* flags */, 0 /* unused */); e != 0 {
+ return e
+ }
+ return nil
+}
+
+// HaveProcessMemoryBarrier returns true if ProcessMemoryBarrier is supported.
+func HaveProcessMemoryBarrier() bool {
+ return haveMembarrierPrivateExpedited
+}
+
+// ProcessMemoryBarrier is equivalent to GlobalMemoryBarrier, but only
+// synchronizes with threads sharing a virtual address space (from the host OS'
+// perspective) with the calling thread.
+//
+// Preconditions: HaveProcessMemoryBarrier() == true.
+func ProcessMemoryBarrier() error {
+ if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 {
+ return e
+ }
+ return nil
+}
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 083071b5e..5de70aecb 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -204,7 +204,6 @@ go_library(
"//pkg/abi",
"//pkg/abi/linux",
"//pkg/amutex",
- "//pkg/binary",
"//pkg/bits",
"//pkg/bpf",
"//pkg/context",
diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go
index d3e76ca7b..060c056df 100644
--- a/pkg/sentry/kernel/kcov.go
+++ b/pkg/sentry/kernel/kcov.go
@@ -89,7 +89,7 @@ func (kcov *Kcov) TaskWork(t *Task) {
kcov.mu.Lock()
defer kcov.mu.Unlock()
- if kcov.mode != linux.KCOV_TRACE_PC {
+ if kcov.mode != linux.KCOV_MODE_TRACE_PC {
return
}
@@ -146,7 +146,7 @@ func (kcov *Kcov) InitTrace(size uint64) error {
}
// EnableTrace performs the KCOV_ENABLE_TRACE ioctl.
-func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
+func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error {
t := TaskFromContext(ctx)
if t == nil {
panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine")
@@ -160,9 +160,9 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
return syserror.EINVAL
}
- switch traceMode {
+ switch traceKind {
case linux.KCOV_TRACE_PC:
- kcov.mode = traceMode
+ kcov.mode = linux.KCOV_MODE_TRACE_PC
case linux.KCOV_TRACE_CMP:
// We do not support KCOV_MODE_TRACE_CMP.
return syserror.ENOTSUP
@@ -175,6 +175,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error {
}
kcov.owningTask = t
+ t.SetKcov(kcov)
t.RegisterWork(kcov)
// Clear existing coverage data; the task expects to read only coverage data
@@ -196,26 +197,35 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error {
if t != kcov.owningTask {
return syserror.EINVAL
}
- kcov.owningTask = nil
kcov.mode = linux.KCOV_MODE_INIT
- kcov.resetLocked()
+ kcov.owningTask = nil
+ kcov.mappable = nil
return nil
}
-// Reset is called when the owning task exits.
-func (kcov *Kcov) Reset() {
+// Clear resets the mode and clears the owning task and memory mapping for kcov.
+// It is called when the fd corresponding to kcov is closed. Note that the mode
+// needs to be set so that the next call to kcov.TaskWork() will exit early.
+func (kcov *Kcov) Clear() {
kcov.mu.Lock()
- kcov.resetLocked()
+ kcov.clearLocked()
kcov.mu.Unlock()
}
-// The kcov instance is reset when the owning task exits or when tracing is
-// disabled.
-func (kcov *Kcov) resetLocked() {
+func (kcov *Kcov) clearLocked() {
+ kcov.mode = linux.KCOV_MODE_INIT
kcov.owningTask = nil
- if kcov.mappable != nil {
- kcov.mappable = nil
- }
+ kcov.mappable = nil
+}
+
+// OnTaskExit is called when the owning task exits. It is similar to
+// kcov.Clear(), except the memory mapping is not cleared, so that the same
+// mapping can be used in the future if kcov is enabled again by another task.
+func (kcov *Kcov) OnTaskExit() {
+ kcov.mu.Lock()
+ kcov.mode = linux.KCOV_MODE_INIT
+ kcov.owningTask = nil
+ kcov.mu.Unlock()
}
// ConfigureMMap is called by the vfs.FileDescription for this kcov instance to
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index d6c21adb7..16c427fc8 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -1738,3 +1738,18 @@ func (k *Kernel) ShmMount() *vfs.Mount {
func (k *Kernel) SocketMount() *vfs.Mount {
return k.socketMount
}
+
+// Release releases resources owned by k.
+//
+// Precondition: This should only be called after the kernel is fully
+// initialized, e.g. after k.Start() has been called.
+func (k *Kernel) Release() {
+ if VFS2Enabled {
+ ctx := k.SupervisorContext()
+ k.hostMount.DecRef(ctx)
+ k.pipeMount.DecRef(ctx)
+ k.shmMount.DecRef(ctx)
+ k.socketMount.DecRef(ctx)
+ k.vfs.Release(ctx)
+ }
+}
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
index c38c5a40c..387edfa91 100644
--- a/pkg/sentry/kernel/seccomp.go
+++ b/pkg/sentry/kernel/seccomp.go
@@ -18,7 +18,6 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/syserror"
@@ -27,25 +26,18 @@ import (
const maxSyscallFilterInstructions = 1 << 15
-// seccompData is equivalent to struct seccomp_data, which contains the data
-// passed to seccomp-bpf filters.
-type seccompData struct {
- // nr is the system call number.
- nr int32
-
- // arch is an AUDIT_ARCH_* value indicating the system call convention.
- arch uint32
-
- // instructionPointer is the value of the instruction pointer at the time
- // of the system call.
- instructionPointer uint64
-
- // args contains the first 6 system call arguments.
- args [6]uint64
-}
-
-func (d *seccompData) asBPFInput() bpf.Input {
- return bpf.InputBytes{binary.Marshal(nil, usermem.ByteOrder, d), usermem.ByteOrder}
+// dataAsBPFInput returns a serialized BPF program, only valid on the current task
+// goroutine.
+//
+// Note: this is called for every syscall, which is a very hot path.
+func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input {
+ buf := t.CopyScratchBuffer(d.SizeBytes())
+ d.MarshalUnsafe(buf)
+ return bpf.InputBytes{
+ Data: buf,
+ // Go-marshal always uses the native byte order.
+ Order: usermem.ByteOrder,
+ }
}
func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo {
@@ -112,20 +104,20 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u
}
func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 {
- data := seccompData{
- nr: sysno,
- arch: t.tc.st.AuditNumber,
- instructionPointer: uint64(ip),
+ data := linux.SeccompData{
+ Nr: sysno,
+ Arch: t.tc.st.AuditNumber,
+ InstructionPointer: uint64(ip),
}
// data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so
// we can't do any slicing tricks or even use copy/append here.
for i, arg := range args {
- if i >= len(data.args) {
+ if i >= len(data.Args) {
break
}
- data.args[i] = arg.Uint64()
+ data.Args[i] = arg.Uint64()
}
- input := data.asBPFInput()
+ input := dataAsBPFInput(t, &data)
ret := uint32(linux.SECCOMP_RET_ALLOW)
f := t.syscallFilters.Load()
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 3eb78e91b..76d472292 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -8,7 +8,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index b07e1c1bd..78f718cfe 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -17,7 +17,6 @@ package signalfd
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
@@ -103,8 +102,7 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
// Copy out the signal info using the specified format.
- var buf [128]byte
- binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ infoNative := linux.SignalfdSiginfo{
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
@@ -113,9 +111,13 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
- })
- n, err := dst.CopyOut(ctx, buf[:])
- return int64(n), err
+ }
+ n, err := infoNative.WriteTo(dst.Writer(ctx))
+ if err == usermem.ErrEndOfIOSequence {
+ // Partial copy-out ok.
+ err = nil
+ }
+ return n, err
}
// Readiness implements waiter.Waitable.Readiness.
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index a436610c9..f796e0fa3 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -917,7 +917,7 @@ func (t *Task) SetKcov(k *Kcov) {
// ResetKcov clears the kcov instance associated with t.
func (t *Task) ResetKcov() {
if t.kcov != nil {
- t.kcov.Reset()
+ t.kcov.OnTaskExit()
t.kcov = nil
}
}
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 5ae5906e8..fdadb52c0 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -265,6 +265,13 @@ func (ns *PIDNamespace) Tasks() []*Task {
return tasks
}
+// NumTasks returns the number of tasks in ns.
+func (ns *PIDNamespace) NumTasks() int {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return len(ns.tids)
+}
+
// ThreadGroups returns a snapshot of the thread groups in ns.
func (ns *PIDNamespace) ThreadGroups() []*ThreadGroup {
return ns.ThreadGroupsAppend(nil)
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index e44a139b3..9bc452e67 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -17,7 +17,6 @@ package kernel
import (
"fmt"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -28,6 +27,8 @@ import (
//
// They are exposed to the VDSO via a parameter page managed by VDSOParamPage,
// which also includes a sequence counter.
+//
+// +marshal
type vdsoParams struct {
monotonicReady uint64
monotonicBaseCycles int64
@@ -68,6 +69,13 @@ type VDSOParamPage struct {
// checked in state_test_util tests, causing this field to change across
// save / restore.
seq uint64
+
+ // copyScratchBuffer is a temporary buffer used to marshal the params before
+ // copying it to the real parameter page. The parameter page is typically
+ // updated at a moderate frequency of ~O(seconds) throughout the lifetime of
+ // the sentry, so reusing this buffer is a good tradeoff between memory
+ // usage and the cost of allocation.
+ copyScratchBuffer []byte
}
// NewVDSOParamPage returns a VDSOParamPage.
@@ -79,7 +87,11 @@ type VDSOParamPage struct {
// * VDSOParamPage must be the only writer to fr.
// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage {
- return &VDSOParamPage{mfp: mfp, fr: fr}
+ return &VDSOParamPage{
+ mfp: mfp,
+ fr: fr,
+ copyScratchBuffer: make([]byte, (*vdsoParams)(nil).SizeBytes()),
+ }
}
// access returns a mapping of the param page.
@@ -133,7 +145,8 @@ func (v *VDSOParamPage) Write(f func() vdsoParams) error {
// Get the new params.
p := f()
- buf := binary.Marshal(nil, usermem.ByteOrder, p)
+ buf := v.copyScratchBuffer[:p.SizeBytes()]
+ p.MarshalUnsafe(buf)
// Skip the sequence counter.
if _, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf)); err != nil {
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index a44fa2b95..7fd77925f 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -127,7 +127,7 @@ func (t Translation) FileRange() FileRange {
// Preconditions: Same as Mappable.Translate.
func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error {
// Verify that the inputs to Mappable.Translate were valid.
- if !required.WellFormed() || required.Length() <= 0 {
+ if !required.WellFormed() || required.Length() == 0 {
panic(fmt.Sprintf("invalid required range: %v", required))
}
if !usermem.Addr(required.Start).IsPageAligned() || !usermem.Addr(required.End).IsPageAligned() {
@@ -145,7 +145,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp
return fmt.Errorf("first Translation %+v does not cover start of required range %v", ts[0], required)
}
for i, t := range ts {
- if !t.Source.WellFormed() || t.Source.Length() <= 0 {
+ if !t.Source.WellFormed() || t.Source.Length() == 0 {
return fmt.Errorf("Translation %+v has invalid Source", t)
}
if !usermem.Addr(t.Source.Start).IsPageAligned() || !usermem.Addr(t.Source.End).IsPageAligned() {
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 8c9f11cce..92cc87d84 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -235,6 +235,20 @@ type MemoryManager struct {
// vdsoSigReturnAddr is the address of 'vdso_sigreturn'.
vdsoSigReturnAddr uint64
+
+ // membarrierPrivateEnabled is non-zero if EnableMembarrierPrivate has
+ // previously been called. Since, as of this writing,
+ // MEMBARRIER_CMD_PRIVATE_EXPEDITED is implemented as a global memory
+ // barrier, membarrierPrivateEnabled has no other effect.
+ //
+ // membarrierPrivateEnabled is accessed using atomic memory operations.
+ membarrierPrivateEnabled uint32
+
+ // membarrierRSeqEnabled is non-zero if EnableMembarrierRSeq has previously
+ // been called.
+ //
+ // membarrierRSeqEnabled is accessed using atomic memory operations.
+ membarrierRSeqEnabled uint32
}
// vma represents a virtual memory area.
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 30facebf7..7e5f7de64 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -36,7 +36,7 @@ import (
// * ar.Length() != 0.
func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -100,7 +100,7 @@ func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at user
// (i.e. permission checks must have been performed against vmas).
func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Ok() {
@@ -193,7 +193,7 @@ func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrR
// getVecPMAsLocked; other clients should call one of those instead.
func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Ok() {
@@ -223,7 +223,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter
// Need a pma here.
optAR := vseg.Range().Intersect(pgap.Range())
if checkInvariants {
- if optAR.Length() <= 0 {
+ if optAR.Length() == 0 {
panic(fmt.Sprintf("vseg %v and pgap %v do not overlap", vseg, pgap))
}
}
@@ -560,7 +560,7 @@ func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterat
// Invalidate implements memmap.MappingSpace.Invalidate.
func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -583,7 +583,7 @@ func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.Invalidate
// * ar must be page-aligned.
func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -629,7 +629,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat
// * ar must be page-aligned.
func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -715,10 +715,10 @@ func Unpin(prs []PinnedRange) {
// * oldAR and newAR must be page-aligned.
func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
if checkInvariants {
- if !oldAR.WellFormed() || oldAR.Length() <= 0 || !oldAR.IsPageAligned() {
+ if !oldAR.WellFormed() || oldAR.Length() == 0 || !oldAR.IsPageAligned() {
panic(fmt.Sprintf("invalid oldAR: %v", oldAR))
}
- if !newAR.WellFormed() || newAR.Length() <= 0 || !newAR.IsPageAligned() {
+ if !newAR.WellFormed() || newAR.Length() == 0 || !newAR.IsPageAligned() {
panic(fmt.Sprintf("invalid newAR: %v", newAR))
}
if oldAR.Length() > newAR.Length() {
@@ -778,7 +778,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) {
// into mm.pmas.
func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) (pmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().Contains(ar.Start) {
@@ -831,7 +831,7 @@ func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSe
// * pseg.Range().Contains(ar.Start).
func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().Contains(ar.Start) {
@@ -1050,7 +1050,7 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange {
if !pseg.Ok() {
panic("terminal pma iterator")
}
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !pseg.Range().IsSupersetOf(ar) {
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index a2555ba1a..675efdc7c 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -17,6 +17,7 @@ package mm
import (
"fmt"
mrand "math/rand"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -1274,3 +1275,27 @@ func (mm *MemoryManager) VirtualDataSize() uint64 {
defer mm.mappingMu.RUnlock()
return mm.dataAS
}
+
+// EnableMembarrierPrivate causes future calls to IsMembarrierPrivateEnabled to
+// return true.
+func (mm *MemoryManager) EnableMembarrierPrivate() {
+ atomic.StoreUint32(&mm.membarrierPrivateEnabled, 1)
+}
+
+// IsMembarrierPrivateEnabled returns true if mm.EnableMembarrierPrivate() has
+// previously been called.
+func (mm *MemoryManager) IsMembarrierPrivateEnabled() bool {
+ return atomic.LoadUint32(&mm.membarrierPrivateEnabled) != 0
+}
+
+// EnableMembarrierRSeq causes future calls to IsMembarrierRSeqEnabled to
+// return true.
+func (mm *MemoryManager) EnableMembarrierRSeq() {
+ atomic.StoreUint32(&mm.membarrierRSeqEnabled, 1)
+}
+
+// IsMembarrierRSeqEnabled returns true if mm.EnableMembarrierRSeq() has
+// previously been called.
+func (mm *MemoryManager) IsMembarrierRSeqEnabled() bool {
+ return atomic.LoadUint32(&mm.membarrierRSeqEnabled) != 0
+}
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index f769d8294..b8df72813 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -266,7 +266,7 @@ func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 {
// * ar.Length() != 0.
func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -350,7 +350,7 @@ const guardBytes = 256 * usermem.PageSize
// * ar must be page-aligned.
func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -371,7 +371,7 @@ func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange)
// * ar must be page-aligned.
func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator {
if checkInvariants {
- if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() {
+ if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
}
@@ -511,7 +511,7 @@ func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRan
if vseg.ValuePtr().mappable == nil {
panic("MappableRange is meaningless for anonymous vma")
}
- if !ar.WellFormed() || ar.Length() <= 0 {
+ if !ar.WellFormed() || ar.Length() == 0 {
panic(fmt.Sprintf("invalid ar: %v", ar))
}
if !vseg.Range().IsSupersetOf(ar) {
@@ -536,7 +536,7 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange {
if vseg.ValuePtr().mappable == nil {
panic("MappableRange is meaningless for anonymous vma")
}
- if !mr.WellFormed() || mr.Length() <= 0 {
+ if !mr.WellFormed() || mr.Length() == 0 {
panic(fmt.Sprintf("invalid mr: %v", mr))
}
if !vseg.mappableRange().IsSupersetOf(mr) {
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 209b28053..db7d55ef2 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -15,6 +15,7 @@ go_library(
"//pkg/context",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/hostmm",
"//pkg/sentry/memmap",
"//pkg/usermem",
],
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 3970dd81d..dd2bbeb12 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -9,12 +9,12 @@ go_library(
"bluepill.go",
"bluepill_allocator.go",
"bluepill_amd64.go",
- "bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
"bluepill_arm64.go",
"bluepill_arm64.s",
"bluepill_arm64_unsafe.go",
"bluepill_fault.go",
+ "bluepill_impl_amd64.s",
"bluepill_unsafe.go",
"context.go",
"filters_amd64.go",
@@ -56,6 +56,7 @@ go_library(
"//pkg/sentry/time",
"//pkg/sync",
"//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
@@ -78,6 +79,15 @@ go_test(
"//pkg/sentry/platform/kvm/testutil",
"//pkg/sentry/platform/ring0",
"//pkg/sentry/platform/ring0/pagetables",
+ "//pkg/sentry/time",
"//pkg/usermem",
],
)
+
+genrule(
+ name = "bluepill_impl_amd64",
+ srcs = ["bluepill_amd64.s"],
+ outs = ["bluepill_impl_amd64.s"],
+ cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index 2bc34a435..025ea93b5 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -19,11 +19,6 @@
// This is guaranteed to be zero.
#define VCPU_CPU 0x0
-// CPU_SELF is the self reference in ring0's percpu.
-//
-// This is guaranteed to be zero.
-#define CPU_SELF 0x0
-
// Context offsets.
//
// Only limited use of the context is done in the assembly stub below, most is
@@ -44,7 +39,7 @@ begin:
LEAQ VCPU_CPU(AX), BX
BYTE CLI;
check_vcpu:
- MOVQ CPU_SELF(GS), CX
+ MOVQ ENTRY_CPU_SELF(GS), CX
CMPQ BX, CX
JE right_vCPU
wrong_vcpu:
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 03a98512e..0a54dd30d 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -83,5 +83,34 @@ func bluepillStopGuest(c *vCPU) {
//
//go:nosplit
func bluepillReadyStopGuest(c *vCPU) bool {
- return c.runData.readyForInterruptInjection != 0
+ if c.runData.readyForInterruptInjection == 0 {
+ return false
+ }
+
+ if c.runData.ifFlag == 0 {
+ // This is impossible if readyForInterruptInjection is 1.
+ throw("interrupts are disabled")
+ }
+
+ // Disable interrupts if we are in the kernel space.
+ //
+ // When the Sentry switches into the kernel mode, it disables
+ // interrupts. But when goruntime switches on a goroutine which has
+ // been saved in the host mode, it restores flags and this enables
+ // interrupts. See the comment of UserFlagsSet for more details.
+ uregs := userRegs{}
+ err := c.getUserRegisters(&uregs)
+ if err != 0 {
+ throw("failed to get user registers")
+ }
+
+ if ring0.IsKernelFlags(uregs.RFLAGS) {
+ uregs.RFLAGS &^= ring0.KernelFlagsClear
+ err = c.setUserRegisters(&uregs)
+ if err != 0 {
+ throw("failed to set user registers")
+ }
+ return false
+ }
+ return true
}
diff --git a/pkg/sentry/platform/kvm/filters_amd64.go b/pkg/sentry/platform/kvm/filters_amd64.go
index 7d949f1dd..d3d216aa5 100644
--- a/pkg/sentry/platform/kvm/filters_amd64.go
+++ b/pkg/sentry/platform/kvm/filters_amd64.go
@@ -17,14 +17,23 @@ package kvm
import (
"syscall"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
)
// SyscallFilters returns syscalls made exclusively by the KVM platform.
func (*KVM) SyscallFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_ARCH_PRCTL: {},
- syscall.SYS_IOCTL: {},
+ syscall.SYS_ARCH_PRCTL: {},
+ syscall.SYS_IOCTL: {},
+ unix.SYS_MEMBARRIER: []seccomp.Rule{
+ {
+ seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED),
+ seccomp.EqualTo(0),
+ },
+ },
syscall.SYS_MMAP: {},
syscall.SYS_RT_SIGSUSPEND: {},
syscall.SYS_RT_SIGTIMEDWAIT: {},
diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go
index 9245d07c2..21abc2a3d 100644
--- a/pkg/sentry/platform/kvm/filters_arm64.go
+++ b/pkg/sentry/platform/kvm/filters_arm64.go
@@ -17,13 +17,22 @@ package kvm
import (
"syscall"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
)
// SyscallFilters returns syscalls made exclusively by the KVM platform.
func (*KVM) SyscallFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_IOCTL: {},
+ syscall.SYS_IOCTL: {},
+ unix.SYS_MEMBARRIER: []seccomp.Rule{
+ {
+ seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED),
+ seccomp.EqualTo(0),
+ },
+ },
syscall.SYS_MMAP: {},
syscall.SYS_RT_SIGSUSPEND: {},
syscall.SYS_RT_SIGTIMEDWAIT: {},
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index ae813e24e..dd45ad10b 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -63,6 +63,9 @@ type runData struct {
type KVM struct {
platform.NoCPUPreemptionDetection
+ // KVM never changes mm_structs.
+ platform.UseHostProcessMemoryBarrier
+
// machine is the backing VM.
machine *machine
}
@@ -156,15 +159,7 @@ func (*KVM) MaxUserAddress() usermem.Addr {
func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
// Allocate page tables and install system mappings.
pageTables := pagetables.New(newAllocator())
- applyPhysicalRegions(func(pr physicalRegion) bool {
- // Map the kernel in the upper half.
- pageTables.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
- pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
- pr.physical)
- return true // Keep iterating.
- })
+ k.machine.mapUpperHalf(pageTables)
// Return the new address space.
return &addressSpace{
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 5c4b18899..6abaa21c4 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -26,12 +26,16 @@ const (
_KVM_RUN = 0xae80
_KVM_NMI = 0xae9a
_KVM_CHECK_EXTENSION = 0xae03
+ _KVM_GET_TSC_KHZ = 0xaea3
+ _KVM_SET_TSC_KHZ = 0xaea2
_KVM_INTERRUPT = 0x4004ae86
_KVM_SET_MSRS = 0x4008ae89
_KVM_SET_USER_MEMORY_REGION = 0x4020ae46
_KVM_SET_REGS = 0x4090ae82
_KVM_SET_SREGS = 0x4138ae84
+ _KVM_GET_MSRS = 0xc008ae88
_KVM_GET_REGS = 0x8090ae81
+ _KVM_GET_SREGS = 0x8138ae83
_KVM_GET_SUPPORTED_CPUID = 0xc008ae05
_KVM_SET_CPUID2 = 0x4008ae90
_KVM_SET_SIGNAL_MASK = 0x4004ae8b
@@ -79,11 +83,14 @@ const (
)
// KVM hypercall list.
+//
// Canonical list of hypercalls supported.
const (
// On amd64, it uses 'HLT' to leave the guest.
+ //
// Unlike amd64, arm64 can only uses mmio_exit/psci to leave the guest.
- // _KVM_HYPERCALL_VMEXIT is only used on Arm64 for now.
+ //
+ // _KVM_HYPERCALL_VMEXIT is only used on arm64 for now.
_KVM_HYPERCALL_VMEXIT int = iota
_KVM_HYPERCALL_MAX
)
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 9a7be3655..84df0f878 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -101,13 +101,20 @@ const (
// Arm64: Memory Attribute Indirection Register EL1.
const (
- _MT_DEVICE_nGnRnE = 0
- _MT_DEVICE_nGnRE = 1
- _MT_DEVICE_GRE = 2
- _MT_NORMAL_NC = 3
- _MT_NORMAL = 4
- _MT_NORMAL_WT = 5
- _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8)
+ _MT_DEVICE_nGnRnE = 0
+ _MT_DEVICE_nGnRE = 1
+ _MT_DEVICE_GRE = 2
+ _MT_NORMAL_NC = 3
+ _MT_NORMAL = 4
+ _MT_NORMAL_WT = 5
+ _MT_ATTR_DEVICE_nGnRnE = 0x00
+ _MT_ATTR_DEVICE_nGnRE = 0x04
+ _MT_ATTR_DEVICE_GRE = 0x0c
+ _MT_ATTR_NORMAL_NC = 0x44
+ _MT_ATTR_NORMAL_WT = 0xbb
+ _MT_ATTR_NORMAL = 0xff
+ _MT_ATTR_MASK = 0xff
+ _MT_EL1_INIT = (_MT_ATTR_DEVICE_nGnRnE << (_MT_DEVICE_nGnRnE * 8)) | (_MT_ATTR_DEVICE_nGnRE << (_MT_DEVICE_nGnRE * 8)) | (_MT_ATTR_DEVICE_GRE << (_MT_DEVICE_GRE * 8)) | (_MT_ATTR_NORMAL_NC << (_MT_NORMAL_NC * 8)) | (_MT_ATTR_NORMAL << (_MT_NORMAL * 8)) | (_MT_ATTR_NORMAL_WT << (_MT_NORMAL_WT * 8))
)
const (
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 45b3180f1..2e12470aa 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -442,6 +443,22 @@ func TestWrongVCPU(t *testing.T) {
})
}
+func TestRdtsc(t *testing.T) {
+ var i int // Iteration count.
+ kvmTest(t, nil, func(c *vCPU) bool {
+ start := ktime.Rdtsc()
+ bluepill(c)
+ guest := ktime.Rdtsc()
+ redpill()
+ end := ktime.Rdtsc()
+ if start > guest || guest > end {
+ t.Errorf("inconsistent time: start=%d, guest=%d, end=%d", start, guest, end)
+ }
+ i++
+ return i < 100
+ })
+}
+
func BenchmarkApplicationSyscall(b *testing.B) {
var (
i int // Iteration includes machine.Get() / machine.Put().
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index a74930423..455e2bd20 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -155,7 +155,7 @@ func (m *machine) newVCPU() *vCPU {
fd: int(fd),
machine: m,
}
- c.CPU.Init(&m.kernel, c)
+ c.CPU.Init(&m.kernel, c.id, c)
m.vCPUsByID[c.id] = c
// Ensure the signal mask is correct.
@@ -183,9 +183,6 @@ func newMachine(vm int) (*machine, error) {
// Create the machine.
m := &machine{fd: vm}
m.available.L = &m.mu
- m.kernel.Init(ring0.KernelOpts{
- PageTables: pagetables.New(newAllocator()),
- })
// Pull the maximum vCPUs.
maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS)
@@ -197,6 +194,9 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
m.vCPUsByTID = make(map[uint64]*vCPU)
m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+ m.kernel.Init(ring0.KernelOpts{
+ PageTables: pagetables.New(newAllocator()),
+ }, m.maxVCPUs)
// Pull the maximum slots.
maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS)
@@ -219,15 +219,9 @@ func newMachine(vm int) (*machine, error) {
pagetables.MapOpts{AccessType: usermem.AnyAccess},
pr.physical)
- // And keep everything in the upper half.
- m.kernel.PageTables.Map(
- usermem.Addr(ring0.KernelStartAddress|pr.virtual),
- pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
- pr.physical)
-
return true // Keep iterating.
})
+ m.mapUpperHalf(m.kernel.PageTables)
var physicalRegionsReadOnly []physicalRegion
var physicalRegionsAvailable []physicalRegion
@@ -365,6 +359,11 @@ func (m *machine) Destroy() {
// Get gets an available vCPU.
//
// This will return with the OS thread locked.
+//
+// It is guaranteed that if any OS thread TID is in guest, m.vCPUs[TID] points
+// to the vCPU in which the OS thread TID is running. So if Get() returns with
+// the corrent context in guest, the vCPU of it must be the same as what
+// Get() returns.
func (m *machine) Get() *vCPU {
m.mu.RLock()
runtime.LockOSThread()
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index ccaf3a028..c67127d95 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -18,14 +18,17 @@ package kvm
import (
"fmt"
+ "math/big"
"reflect"
"runtime/debug"
"syscall"
+ "gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -131,6 +134,7 @@ func (c *vCPU) initArchState() error {
// Set the entrypoint for the kernel.
kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer())
kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ kernelUserRegs.RSP = c.StackTop()
kernelUserRegs.RFLAGS = ring0.KernelFlagsSet
// Set the system registers.
@@ -139,8 +143,8 @@ func (c *vCPU) initArchState() error {
}
// Set the user registers.
- if err := c.setUserRegisters(&kernelUserRegs); err != nil {
- return err
+ if errno := c.setUserRegisters(&kernelUserRegs); errno != 0 {
+ return fmt.Errorf("error setting user registers: %v", errno)
}
// Allocate some floating point state save area for the local vCPU.
@@ -153,6 +157,133 @@ func (c *vCPU) initArchState() error {
return c.setSystemTime()
}
+// bitsForScaling returns the bits available for storing the fraction component
+// of the TSC scaling ratio. This allows us to replicate the (bad) math done by
+// the kernel below in scaledTSC, and ensure we can compute an exact zero
+// offset in setSystemTime.
+//
+// These constants correspond to kvm_tsc_scaling_ratio_frac_bits.
+var bitsForScaling = func() int64 {
+ fs := cpuid.HostFeatureSet()
+ if fs.Intel() {
+ return 48 // See vmx.c (kvm sources).
+ } else if fs.AMD() {
+ return 32 // See svm.c (svm sources).
+ } else {
+ return 63 // Unknown: theoretical maximum.
+ }
+}()
+
+// scaledTSC returns the host TSC scaled by the given frequency.
+//
+// This assumes a current frequency of 1. We require only the unitless ratio of
+// rawFreq to some current frequency. See setSystemTime for context.
+//
+// The kernel math guarantees that all bits of the multiplication and division
+// will be correctly preserved and applied. However, it is not possible to
+// actually store the ratio correctly. So we need to use the same schema in
+// order to calculate the scaled frequency and get the same result.
+//
+// We can assume that the current frequency is (1), so we are calculating a
+// strict inverse of this value. This simplifies this function considerably.
+//
+// Roughly, the returned value "scaledTSC" will have:
+// scaledTSC/hostTSC == 1/rawFreq
+//
+//go:nosplit
+func scaledTSC(rawFreq uintptr) int64 {
+ scale := int64(1 << bitsForScaling)
+ ratio := big.NewInt(scale / int64(rawFreq))
+ ratio.Mul(ratio, big.NewInt(int64(ktime.Rdtsc())))
+ ratio.Div(ratio, big.NewInt(scale))
+ return ratio.Int64()
+}
+
+// setSystemTime sets the vCPU to the system time.
+func (c *vCPU) setSystemTime() error {
+ // First, scale down the clock frequency to the lowest value allowed by
+ // the API itself. How low we can go depends on the underlying
+ // hardware, but it is typically ~1/2^48 for Intel, ~1/2^32 for AMD.
+ // Even the lower bound here will take a 4GHz frequency down to 1Hz,
+ // meaning that everything should be able to handle a Khz setting of 1
+ // with bits to spare.
+ //
+ // Note that reducing the clock does not typically require special
+ // capabilities as it is emulated in KVM. We don't actually use this
+ // capability, but it means that this method should be robust to
+ // different hardware configurations.
+ rawFreq, err := c.getTSCFreq()
+ if err != nil {
+ return c.setSystemTimeLegacy()
+ }
+ if err := c.setTSCFreq(1); err != nil {
+ return c.setSystemTimeLegacy()
+ }
+
+ // Always restore the original frequency.
+ defer func() {
+ if err := c.setTSCFreq(rawFreq); err != nil {
+ panic(err.Error())
+ }
+ }()
+
+ // Attempt to set the system time in this compressed world. The
+ // calculation for offset normally looks like:
+ //
+ // offset = target_tsc - kvm_scale_tsc(vcpu, rdtsc());
+ //
+ // So as long as the kvm_scale_tsc component is constant before and
+ // after the call to set the TSC value (and it is passes as the
+ // target_tsc), we will compute an offset value of zero.
+ //
+ // This is effectively cheating to make our "setSystemTime" call so
+ // unbelievably, incredibly fast that we do it "instantly" and all the
+ // calculations result in an offset of zero.
+ lastTSC := scaledTSC(rawFreq)
+ for {
+ if err := c.setTSC(uint64(lastTSC)); err != nil {
+ return err
+ }
+ nextTSC := scaledTSC(rawFreq)
+ if lastTSC == nextTSC {
+ return nil
+ }
+ lastTSC = nextTSC // Try again.
+ }
+}
+
+// setSystemTimeLegacy calibrates and sets an approximate system time.
+func (c *vCPU) setSystemTimeLegacy() error {
+ const minIterations = 10
+ minimum := uint64(0)
+ for iter := 0; ; iter++ {
+ // Try to set the TSC to an estimate of where it will be
+ // on the host during a "fast" system call iteration.
+ start := uint64(ktime.Rdtsc())
+ if err := c.setTSC(start + (minimum / 2)); err != nil {
+ return err
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result.
+ end := uint64(ktime.Rdtsc())
+ if end < start {
+ continue // Totally bogus: unstable TSC?
+ }
+ current := end - start
+ if current < minimum || iter == 0 {
+ minimum = current // Set our new minimum.
+ }
+ // Is this past minIterations and within ~10% of minimum?
+ upperThreshold := (((minimum << 3) + minimum) >> 3)
+ if iter >= minIterations && current <= upperThreshold {
+ return nil
+ }
+ }
+}
+
// nonCanonical generates a canonical address return.
//
//go:nosplit
@@ -332,3 +463,41 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return physicalRegions
}
+
+var execRegions = func() (regions []region) {
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
+ return
+ }
+ if vr.accessType.Execute {
+ regions = append(regions, vr.region)
+ }
+ })
+ return
+}()
+
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ for _, r := range execRegions {
+ physical, length, ok := translateToPhysical(r.virtual)
+ if !ok || length < r.length {
+ panic("impossilbe translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|r.virtual),
+ r.length,
+ pagetables.MapOpts{AccessType: usermem.Execute},
+ physical)
+ }
+ for start, end := range m.kernel.EntryRegions() {
+ regionLen := end - start
+ physical, length, ok := translateToPhysical(start)
+ if !ok || length < regionLen {
+ panic("impossible translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|start),
+ regionLen,
+ pagetables.MapOpts{AccessType: usermem.ReadWrite},
+ physical)
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 290f035dd..b430f92c6 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -23,7 +23,6 @@ import (
"unsafe"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/time"
)
// loadSegments copies the current segments.
@@ -61,91 +60,63 @@ func (c *vCPU) setCPUID() error {
return nil
}
-// setSystemTime sets the TSC for the vCPU.
+// getTSCFreq gets the TSC frequency.
//
-// This has to make the call many times in order to minimize the intrinsic
-// error in the offset. Unfortunately KVM does not expose a relative offset via
-// the API, so this is an approximation. We do this via an iterative algorithm.
-// This has the advantage that it can generally deal with highly variable
-// system call times and should converge on the correct offset.
-func (c *vCPU) setSystemTime() error {
- const (
- _MSR_IA32_TSC = 0x00000010
- calibrateTries = 10
- )
- registers := modelControlRegisters{
- nmsrs: 1,
- }
- registers.entries[0] = modelControlRegister{
- index: _MSR_IA32_TSC,
+// If mustSucceed is true, then this function panics on error.
+func (c *vCPU) getTSCFreq() (uintptr, error) {
+ rawFreq, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_TSC_KHZ,
+ 0 /* ignored */)
+ if errno != 0 {
+ return 0, errno
}
- target := uint64(^uint32(0))
- for done := 0; done < calibrateTries; {
- start := uint64(time.Rdtsc())
- registers.entries[0].data = start + target
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_MSRS,
- uintptr(unsafe.Pointer(&registers))); errno != 0 {
- return fmt.Errorf("error setting system time: %v", errno)
- }
- // See if this is our new minimum call time. Note that this
- // serves two functions: one, we make sure that we are
- // accurately predicting the offset we need to set. Second, we
- // don't want to do the final set on a slow call, which could
- // produce a really bad result. So we only count attempts
- // within +/- 6.25% of our minimum as an attempt.
- end := uint64(time.Rdtsc())
- if end < start {
- continue // Totally bogus.
- }
- half := (end - start) / 2
- if half < target {
- target = half
- }
- if (half - target) < target/8 {
- done++
- }
+ return rawFreq, nil
+}
+
+// setTSCFreq sets the TSC frequency.
+func (c *vCPU) setTSCFreq(freq uintptr) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_TSC_KHZ,
+ freq /* khz */); errno != 0 {
+ return fmt.Errorf("error setting TSC frequency: %v", errno)
}
return nil
}
-// setSignalMask sets the vCPU signal mask.
-//
-// This must be called prior to running the vCPU.
-func (c *vCPU) setSignalMask() error {
- // The layout of this structure implies that it will not necessarily be
- // the same layout chosen by the Go compiler. It gets fudged here.
- var data struct {
- length uint32
- mask1 uint32
- mask2 uint32
- _ uint32
+// setTSC sets the TSC value.
+func (c *vCPU) setTSC(value uint64) error {
+ const _MSR_IA32_TSC = 0x00000010
+ registers := modelControlRegisters{
+ nmsrs: 1,
}
- data.length = 8 // Fixed sigset size.
- data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
- data.mask2 = ^uint32(bounceSignalMask >> 32)
+ registers.entries[0].index = _MSR_IA32_TSC
+ registers.entries[0].data = value
if _, _, errno := syscall.RawSyscall(
syscall.SYS_IOCTL,
uintptr(c.fd),
- _KVM_SET_SIGNAL_MASK,
- uintptr(unsafe.Pointer(&data))); errno != 0 {
- return fmt.Errorf("error setting signal mask: %v", errno)
+ _KVM_SET_MSRS,
+ uintptr(unsafe.Pointer(&registers))); errno != 0 {
+ return fmt.Errorf("error setting tsc: %v", errno)
}
return nil
}
// setUserRegisters sets user registers in the vCPU.
-func (c *vCPU) setUserRegisters(uregs *userRegs) error {
+//
+//go:nosplit
+func (c *vCPU) setUserRegisters(uregs *userRegs) syscall.Errno {
if _, _, errno := syscall.RawSyscall(
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_REGS,
uintptr(unsafe.Pointer(uregs))); errno != 0 {
- return fmt.Errorf("error setting user registers: %v", errno)
+ return errno
}
- return nil
+ return 0
}
// getUserRegisters reloads user registers in the vCPU.
@@ -175,3 +146,17 @@ func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
}
return nil
}
+
+// getSystemRegisters sets system registers.
+//
+//go:nosplit
+func (c *vCPU) getSystemRegisters(sregs *systemRegs) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_SREGS,
+ uintptr(unsafe.Pointer(sregs))); errno != 0 {
+ return errno
+ }
+ return 0
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 1d247f0dd..54837f20c 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -19,6 +19,7 @@ package kvm
import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -48,6 +49,18 @@ const (
poolPCIDs = 8
)
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ applyPhysicalRegions(func(pr physicalRegion) bool {
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|pr.virtual),
+ pr.length,
+ pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pr.physical)
+
+ return true // Keep iterating.
+ })
+}
+
// Get all read-only physicalRegions.
func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
var rdonlyRegions []region
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 537419657..a163f956d 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -191,42 +191,6 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
return nil
}
-// setCPUID sets the CPUID to be used by the guest.
-func (c *vCPU) setCPUID() error {
- return nil
-}
-
-// setSystemTime sets the TSC for the vCPU.
-func (c *vCPU) setSystemTime() error {
- return nil
-}
-
-// setSignalMask sets the vCPU signal mask.
-//
-// This must be called prior to running the vCPU.
-func (c *vCPU) setSignalMask() error {
- // The layout of this structure implies that it will not necessarily be
- // the same layout chosen by the Go compiler. It gets fudged here.
- var data struct {
- length uint32
- mask1 uint32
- mask2 uint32
- _ uint32
- }
- data.length = 8 // Fixed sigset size.
- data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
- data.mask2 = ^uint32(bounceSignalMask >> 32)
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_SIGNAL_MASK,
- uintptr(unsafe.Pointer(&data))); errno != 0 {
- return fmt.Errorf("error setting signal mask: %v", errno)
- }
-
- return nil
-}
-
// SwitchToUser unpacks architectural-details.
func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
// Check for canonical addresses.
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index 607c82156..1d6ca245a 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -143,3 +143,29 @@ func (c *vCPU) waitUntilNot(state uint32) {
panic("futex wait error")
}
}
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure implies that it will not necessarily be
+ // the same layout chosen by the Go compiler. It gets fudged here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index 530e779b0..dcfe839a7 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/hostmm"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -52,6 +53,10 @@ type Platform interface {
// can reliably return ErrContextCPUPreempted.
DetectsCPUPreemption() bool
+ // HaveGlobalMemoryBarrier returns true if the GlobalMemoryBarrier method
+ // is supported.
+ HaveGlobalMemoryBarrier() bool
+
// MapUnit returns the alignment used for optional mappings into this
// platform's AddressSpaces. Higher values indicate lower per-page costs
// for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates
@@ -97,6 +102,15 @@ type Platform interface {
// called.
PreemptAllCPUs() error
+ // GlobalMemoryBarrier blocks until all threads running application code
+ // (via Context.Switch) and all task goroutines "have passed through a
+ // state where all memory accesses to user-space addresses match program
+ // order between entry to and return from [GlobalMemoryBarrier]", as for
+ // membarrier(2).
+ //
+ // Preconditions: HaveGlobalMemoryBarrier() == true.
+ GlobalMemoryBarrier() error
+
// SyscallFilters returns syscalls made exclusively by this platform.
SyscallFilters() seccomp.SyscallRules
}
@@ -115,6 +129,43 @@ func (NoCPUPreemptionDetection) PreemptAllCPUs() error {
panic("This platform does not support CPU preemption detection")
}
+// UseHostGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and
+// Platform.GlobalMemoryBarrier by invoking equivalent functionality on the
+// host.
+type UseHostGlobalMemoryBarrier struct{}
+
+// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier.
+func (UseHostGlobalMemoryBarrier) HaveGlobalMemoryBarrier() bool {
+ return hostmm.HaveGlobalMemoryBarrier()
+}
+
+// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier.
+func (UseHostGlobalMemoryBarrier) GlobalMemoryBarrier() error {
+ return hostmm.GlobalMemoryBarrier()
+}
+
+// UseHostProcessMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and
+// Platform.GlobalMemoryBarrier by invoking a process-local memory barrier.
+// This is faster than UseHostGlobalMemoryBarrier, but is only appropriate for
+// platforms for which application code executes while using the sentry's
+// mm_struct.
+type UseHostProcessMemoryBarrier struct{}
+
+// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier.
+func (UseHostProcessMemoryBarrier) HaveGlobalMemoryBarrier() bool {
+ // Fall back to a global memory barrier if a process-local one isn't
+ // available.
+ return hostmm.HaveProcessMemoryBarrier() || hostmm.HaveGlobalMemoryBarrier()
+}
+
+// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier.
+func (UseHostProcessMemoryBarrier) GlobalMemoryBarrier() error {
+ if hostmm.HaveProcessMemoryBarrier() {
+ return hostmm.ProcessMemoryBarrier()
+ }
+ return hostmm.GlobalMemoryBarrier()
+}
+
// MemoryManager represents an abstraction above the platform address space
// which manages memory mappings and their contents.
type MemoryManager interface {
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index b52d0fbd8..f56aa3b79 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -192,6 +192,7 @@ func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {}
type PTrace struct {
platform.MMapMinAddr
platform.NoCPUPreemptionDetection
+ platform.UseHostGlobalMemoryBarrier
}
// New returns a new ptrace-based implementation of the platform interface.
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 9c6c2cf5c..00899273e 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -76,15 +76,41 @@ type KernelOpts struct {
type KernelArchState struct {
KernelOpts
+ // cpuEntries is array of kernelEntry for all cpus
+ cpuEntries []kernelEntry
+
// globalIDT is our set of interrupt gates.
- globalIDT idt64
+ globalIDT *idt64
}
-// CPUArchState contains CPU-specific arch state.
-type CPUArchState struct {
+// kernelEntry contains minimal CPU-specific arch state
+// that can be mapped at the upper of the address space.
+// Malicious APP might steal info from it via CPU bugs.
+type kernelEntry struct {
// stack is the stack used for interrupts on this CPU.
stack [256]byte
+ // scratch space for temporary usage.
+ scratch0 uint64
+
+ // stackTop is the top of the stack.
+ stackTop uint64
+
+ // cpuSelf is back reference to CPU.
+ cpuSelf *CPU
+
+ // kernelCR3 is the cr3 used for sentry kernel.
+ kernelCR3 uintptr
+
+ // gdt is the CPU's descriptor table.
+ gdt descriptorTable
+
+ // tss is the CPU's task state.
+ tss TaskState64
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
// errorCode is the error code from the last exception.
errorCode uintptr
@@ -97,11 +123,7 @@ type CPUArchState struct {
// exception.
errorType uintptr
- // gdt is the CPU's descriptor table.
- gdt descriptorTable
-
- // tss is the CPU's task state.
- tss TaskState64
+ *kernelEntry
}
// ErrorCode returns the last error code.
diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go
index 7fa43c2f5..d87b1fd00 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.go
+++ b/pkg/sentry/platform/ring0/entry_amd64.go
@@ -36,12 +36,15 @@ func sysenter()
// This must be called prior to sysret/iret.
func swapgs()
+// jumpToKernel jumps to the kernel version of the current RIP.
+func jumpToKernel()
+
// sysret returns to userspace from a system call.
//
// The return code is the vector that interrupted execution.
//
// See stubs.go for a note regarding the frame size of this function.
-func sysret(*CPU, *arch.Registers) Vector
+func sysret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector
// "iret is the cadillac of CPL switching."
//
@@ -50,7 +53,7 @@ func sysret(*CPU, *arch.Registers) Vector
// iret is nearly identical to sysret, except an iret is used to fully restore
// all user state. This must be called in cases where all registers need to be
// restored.
-func iret(*CPU, *arch.Registers) Vector
+func iret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector
// exception is the generic exception entry.
//
diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s
index 02df38331..f59747df3 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.s
+++ b/pkg/sentry/platform/ring0/entry_amd64.s
@@ -63,6 +63,15 @@
MOVQ offset+PTRACE_RSI(reg), SI; \
MOVQ offset+PTRACE_RDI(reg), DI;
+// WRITE_CR3() writes the given CR3 value.
+//
+// The code corresponds to:
+//
+// mov %rax, %cr3
+//
+#define WRITE_CR3() \
+ BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
+
// SWAP_GS swaps the kernel GS (CPU).
#define SWAP_GS() \
BYTE $0x0F; BYTE $0x01; BYTE $0xf8;
@@ -75,15 +84,9 @@
#define SYSRET64() \
BYTE $0x48; BYTE $0x0f; BYTE $0x07;
-// LOAD_KERNEL_ADDRESS loads a kernel address.
-#define LOAD_KERNEL_ADDRESS(from, to) \
- MOVQ from, to; \
- ORQ ·KernelStartAddress(SB), to;
-
// LOAD_KERNEL_STACK loads the kernel stack.
-#define LOAD_KERNEL_STACK(from) \
- LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \
- LEAQ CPU_STACK_TOP(SP), SP;
+#define LOAD_KERNEL_STACK(entry) \
+ MOVQ ENTRY_STACK_TOP(entry), SP;
// See kernel.go.
TEXT ·Halt(SB),NOSPLIT,$0
@@ -95,58 +98,93 @@ TEXT ·swapgs(SB),NOSPLIT,$0
SWAP_GS()
RET
+// jumpToKernel changes execution to the kernel address space.
+//
+// This works by changing the return value to the kernel version.
+TEXT ·jumpToKernel(SB),NOSPLIT,$0
+ MOVQ 0(SP), AX
+ ORQ ·KernelStartAddress(SB), AX // Future return value.
+ MOVQ AX, 0(SP)
+ RET
+
// See entry_amd64.go.
TEXT ·sysret(SB),NOSPLIT,$0-24
- // Save original state.
- LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
- LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ CALL ·jumpToKernel(SB)
+ // Save original state and stack. sysenter() or exception()
+ // from APP(gr3) will switch to this stack, set the return
+ // value (vector: 32(SP)) and then do RET, which will also
+ // automatically return to the lower half.
+ MOVQ cpu+0(FP), BX
+ MOVQ regs+8(FP), AX
+ MOVQ userCR3+16(FP), CX
MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
+ // save SP AX userCR3 on the kernel stack.
+ MOVQ CPU_ENTRY(BX), BX
+ LOAD_KERNEL_STACK(BX)
+ PUSHQ PTRACE_RSP(AX)
+ PUSHQ PTRACE_RAX(AX)
+ PUSHQ CX
+
// Restore user register state.
REGISTERS_LOAD(AX, 0)
MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET.
MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET.
- MOVQ PTRACE_RSP(AX), SP // Restore the stack directly.
- MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+
+ // restore userCR3, AX, SP.
+ POPQ AX // Get userCR3.
+ WRITE_CR3() // Switch to userCR3.
+ POPQ AX // Restore AX.
+ POPQ SP // Restore SP.
SYSRET64()
// See entry_amd64.go.
TEXT ·iret(SB),NOSPLIT,$0-24
- // Save original state.
- LOAD_KERNEL_ADDRESS(cpu+0(FP), BX)
- LOAD_KERNEL_ADDRESS(regs+8(FP), AX)
+ CALL ·jumpToKernel(SB)
+ // Save original state and stack. sysenter() or exception()
+ // from APP(gr3) will switch to this stack, set the return
+ // value (vector: 32(SP)) and then do RET, which will also
+ // automatically return to the lower half.
+ MOVQ cpu+0(FP), BX
+ MOVQ regs+8(FP), AX
+ MOVQ userCR3+16(FP), CX
MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX)
MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX)
MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX)
// Build an IRET frame & restore state.
+ MOVQ CPU_ENTRY(BX), BX
LOAD_KERNEL_STACK(BX)
- MOVQ PTRACE_SS(AX), BX; PUSHQ BX
- MOVQ PTRACE_RSP(AX), CX; PUSHQ CX
- MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX
- MOVQ PTRACE_CS(AX), DI; PUSHQ DI
- MOVQ PTRACE_RIP(AX), SI; PUSHQ SI
- REGISTERS_LOAD(AX, 0) // Restore most registers.
- MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch).
+ PUSHQ PTRACE_SS(AX)
+ PUSHQ PTRACE_RSP(AX)
+ PUSHQ PTRACE_FLAGS(AX)
+ PUSHQ PTRACE_CS(AX)
+ PUSHQ PTRACE_RIP(AX)
+ PUSHQ PTRACE_RAX(AX) // Save AX on kernel stack.
+ PUSHQ CX // Save userCR3 on kernel stack.
+ REGISTERS_LOAD(AX, 0) // Restore most registers.
+ POPQ AX // Get userCR3.
+ WRITE_CR3() // Switch to userCR3.
+ POPQ AX // Restore AX.
IRET()
// See entry_amd64.go.
TEXT ·resume(SB),NOSPLIT,$0
// See iret, above.
- MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX
- MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX
- MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI
- MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI
- REGISTERS_LOAD(GS, CPU_REGISTERS)
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ PUSHQ CPU_REGISTERS+PTRACE_SS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_RSP(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_FLAGS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_CS(AX)
+ PUSHQ CPU_REGISTERS+PTRACE_RIP(AX)
+ REGISTERS_LOAD(AX, CPU_REGISTERS)
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX
IRET()
// See entry_amd64.go.
TEXT ·Start(SB),NOSPLIT,$0
- LOAD_KERNEL_STACK(AX) // Set the stack.
PUSHQ $0x0 // Previous frame pointer.
MOVQ SP, BP // Set frame pointer.
PUSHQ AX // First argument (CPU).
@@ -155,53 +193,60 @@ TEXT ·Start(SB),NOSPLIT,$0
// See entry_amd64.go.
TEXT ·sysenter(SB),NOSPLIT,$0
- // Interrupts are always disabled while we're executing in kernel mode
- // and always enabled while executing in user mode. Therefore, we can
- // reliably look at the flags in R11 to determine where this syscall
- // was from.
- TESTL $_RFLAGS_IF, R11
+ // _RFLAGS_IOPL0 is always set in the user mode and it is never set in
+ // the kernel mode. See the comment of UserFlagsSet for more details.
+ TESTL $_RFLAGS_IOPL0, R11
JZ kernel
-
user:
SWAP_GS()
- XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks.
- XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs).
+ MOVQ AX, ENTRY_SCRATCH0(GS) // Save user AX on scratch.
+ MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX.
+ WRITE_CR3() // Switch to kernel cr3.
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs.
REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value.
- MOVQ BX, PTRACE_RAX(AX) // Save everything else.
- MOVQ BX, PTRACE_ORIGRAX(AX)
MOVQ CX, PTRACE_RIP(AX)
MOVQ R11, PTRACE_FLAGS(AX)
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX)
- MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
- MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
+ MOVQ SP, PTRACE_RSP(AX)
+ MOVQ ENTRY_SCRATCH0(GS), CX // Load saved user AX value.
+ MOVQ CX, PTRACE_RAX(AX) // Save everything else.
+ MOVQ CX, PTRACE_ORIGRAX(AX)
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Get stacks.
+ MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
+ MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user.
// Return to the kernel, where the frame is:
//
- // vector (sp+24)
+ // vector (sp+32)
+ // userCR3 (sp+24)
// regs (sp+16)
// cpu (sp+8)
// vcpu.Switch (sp+0)
//
- MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
- MOVQ $Syscall, 24(SP) // Output vector.
+ MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer.
+ MOVQ $Syscall, 32(SP) // Output vector.
RET
kernel:
// We can't restore the original stack, but we can access the registers
// in the CPU state directly. No need for temporary juggling.
- MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
- MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
- REGISTERS_SAVE(GS, CPU_REGISTERS)
- MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS)
- MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS)
- MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS)
- MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code.
- MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ AX, ENTRY_SCRATCH0(GS)
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ REGISTERS_SAVE(AX, CPU_REGISTERS)
+ MOVQ CX, CPU_REGISTERS+PTRACE_RIP(AX)
+ MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(AX)
+ MOVQ SP, CPU_REGISTERS+PTRACE_RSP(AX)
+ MOVQ ENTRY_SCRATCH0(GS), BX
+ MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX)
+ MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX)
+ MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code.
+ MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
// Call the syscall trampoline.
LOAD_KERNEL_STACK(GS)
- MOVQ CPU_SELF(GS), AX // Load vCPU.
PUSHQ AX // First argument (vCPU).
CALL ·kernelSyscall(SB) // Call the trampoline.
POPQ AX // Pop vCPU.
@@ -230,16 +275,21 @@ TEXT ·exception(SB),NOSPLIT,$0
// ERROR_CODE (sp+8)
// VECTOR (sp+0)
//
- TESTL $_RFLAGS_IF, 32(SP)
+ TESTL $_RFLAGS_IOPL0, 32(SP)
JZ kernel
user:
SWAP_GS()
ADDQ $-8, SP // Adjust for flags.
MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ).
- XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs.
+ PUSHQ AX // Save user AX on stack.
+ MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX.
+ WRITE_CR3() // Switch to kernel cr3.
+
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs.
REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX.
- MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX.
+ POPQ BX // Restore original AX.
MOVQ BX, PTRACE_RAX(AX) // Save it.
MOVQ BX, PTRACE_ORIGRAX(AX)
MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX)
@@ -249,34 +299,36 @@ user:
MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX)
// Copy out and return.
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
MOVQ 0(SP), BX // Load vector.
MOVQ 8(SP), CX // Load error code.
- MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version).
- MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer.
- MOVQ CX, CPU_ERROR_CODE(GS) // Set error code.
- MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user.
- MOVQ BX, 24(SP) // Output vector.
+ MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Original stack (kernel version).
+ MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer.
+ MOVQ CX, CPU_ERROR_CODE(AX) // Set error code.
+ MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user.
+ MOVQ BX, 32(SP) // Output vector.
RET
kernel:
// As per above, we can save directly.
- MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS)
- MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS)
- REGISTERS_SAVE(GS, CPU_REGISTERS)
- MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS)
- MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS)
- MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS)
+ PUSHQ AX
+ MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU.
+ REGISTERS_SAVE(AX, CPU_REGISTERS)
+ POPQ BX
+ MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX)
+ MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX)
+ MOVQ 16(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RIP(AX)
+ MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(AX)
+ MOVQ 40(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RSP(AX)
// Set the error code and adjust the stack.
- MOVQ 8(SP), AX // Load the error code.
- MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU.
- MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel.
+ MOVQ 8(SP), BX // Load the error code.
+ MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU.
+ MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel.
MOVQ 0(SP), BX // BX contains the vector.
- ADDQ $48, SP // Drop the exception frame.
// Call the exception trampoline.
LOAD_KERNEL_STACK(GS)
- MOVQ CPU_SELF(GS), AX // Load vCPU.
PUSHQ BX // Second argument (vector).
PUSHQ AX // First argument (vCPU).
CALL ·kernelException(SB) // Call the trampoline.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 5f63cbd45..494baaa4d 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -47,8 +47,9 @@
#define SCTLR_C 1 << 2
#define SCTLR_I 1 << 12
#define SCTLR_UCT 1 << 15
+#define SCTLR_UCI 1 << 26
-#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT)
+#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT | SCTLR_UCI)
// cntkctl_el1: counter-timer kernel control register el1.
#define CNTKCTL_EL0PCTEN 1 << 0
@@ -342,6 +343,8 @@
ADD $16, RSP, RSP; \
MOVD RSV_REG, PTRACE_R18(R20); \
MOVD RSV_REG_APP, PTRACE_R9(R20); \
+ MRS TPIDR_EL0, R3; \
+ MOVD R3, PTRACE_TLS(R20); \
WORD $0xd5384003; \ // MRS SPSR_EL1, R3
MOVD R3, PTRACE_PSTATE(R20); \
MRS ELR_EL1, R3; \
@@ -354,6 +357,8 @@
WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context.
MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \
+ MRS TPIDR_EL0, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_TLS(RSV_REG); \
WORD $0xd5384004; \ // MRS SPSR_EL1, R4
MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \
MRS ELR_EL1, R4; \
@@ -435,6 +440,8 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
MRS TPIDR_EL1, RSV_REG
REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+ MRS TPIDR_EL0, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_TLS(RSV_REG)
WORD $0xd5384003 // MRS SPSR_EL1, R3
MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG)
@@ -461,8 +468,18 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
MOVD PTRACE_PSTATE(RSV_REG_APP), R1
WORD $0xd5184001 //MSR R1, SPSR_EL1
+ // need use kernel space address to excute below code, since
+ // after SWITCH_TO_APP_PAGETABLE the ASID is changed to app's
+ // ASID.
+ WORD $0x10000061 // ADR R1, do_exit_to_el0
+ ORR $0xffff000000000000, R1, R1
+ JMP (R1)
+
+do_exit_to_el0:
// RSV_REG & RSV_REG_APP will be loaded at the end.
REGISTERS_LOAD(RSV_REG_APP, 0)
+ MOVD PTRACE_TLS(RSV_REG_APP), RSV_REG
+ MSR RSV_REG, TPIDR_EL0
// switch to user pagetable.
MOVD PTRACE_R18(RSV_REG_APP), RSV_REG
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 549f3d228..9742308d8 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -24,7 +24,10 @@ go_binary(
"defs_impl_arm64.go",
"main.go",
],
- visibility = ["//pkg/sentry/platform/ring0:__pkg__"],
+ visibility = [
+ "//pkg/sentry/platform/kvm:__pkg__",
+ "//pkg/sentry/platform/ring0:__pkg__",
+ ],
deps = [
"//pkg/cpuid",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 021693791..264be23d3 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -19,8 +19,8 @@ package ring0
// N.B. that constraints on KernelOpts must be satisfied.
//
//go:nosplit
-func (k *Kernel) Init(opts KernelOpts) {
- k.init(opts)
+func (k *Kernel) Init(opts KernelOpts, maxCPUs int) {
+ k.init(opts, maxCPUs)
}
// Halt halts execution.
@@ -49,6 +49,11 @@ func (defaultHooks) KernelException(Vector) {
// kernelSyscall is a trampoline.
//
+// When in amd64, it is called with %rip on the upper half, so it can
+// NOT access to any global data which is not mapped on upper and must
+// call to function pointers or interfaces to switch to the lower half
+// so that callee can access to global data.
+//
// +checkescape:hard,stack
//
//go:nosplit
@@ -58,6 +63,11 @@ func kernelSyscall(c *CPU) {
// kernelException is a trampoline.
//
+// When in amd64, it is called with %rip on the upper half, so it can
+// NOT access to any global data which is not mapped on upper and must
+// call to function pointers or interfaces to switch to the lower half
+// so that callee can access to global data.
+//
// +checkescape:hard,stack
//
//go:nosplit
@@ -68,10 +78,10 @@ func kernelException(c *CPU, vector Vector) {
// Init initializes a new CPU.
//
// Init allows embedding in other objects.
-func (c *CPU) Init(k *Kernel, hooks Hooks) {
- c.self = c // Set self reference.
- c.kernel = k // Set kernel reference.
- c.init() // Perform architectural init.
+func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) {
+ c.self = c // Set self reference.
+ c.kernel = k // Set kernel reference.
+ c.init(cpuID) // Perform architectural init.
// Require hooks.
if hooks != nil {
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index d37981dbf..3a9dff4cc 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -18,13 +18,42 @@ package ring0
import (
"encoding/binary"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/usermem"
)
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts) {
+func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
// Save the root page tables.
k.PageTables = opts.PageTables
+ entrySize := reflect.TypeOf(kernelEntry{}).Size()
+ var (
+ entries []kernelEntry
+ padding = 1
+ )
+ for {
+ entries = make([]kernelEntry, maxCPUs+padding-1)
+ totalSize := entrySize * uintptr(maxCPUs+padding-1)
+ addr := reflect.ValueOf(&entries[0]).Pointer()
+ if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize {
+ // The runtime forces power-of-2 alignment for allocations, and we are therefore
+ // safe once the first address is aligned and the chunk is at least a full page.
+ break
+ }
+ padding = padding << 1
+ }
+ k.cpuEntries = entries
+
+ k.globalIDT = &idt64{}
+ if reflect.TypeOf(idt64{}).Size() != usermem.PageSize {
+ panic("Size of globalIDT should be PageSize")
+ }
+ if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 {
+ panic("Allocated globalIDT should be page aligned")
+ }
+
// Setup the IDT, which is uniform.
for v, handler := range handlers {
// Allow Breakpoint and Overflow to be called from all
@@ -39,8 +68,26 @@ func (k *Kernel) init(opts KernelOpts) {
}
}
+func (k *Kernel) EntryRegions() map[uintptr]uintptr {
+ regions := make(map[uintptr]uintptr)
+
+ addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer()
+ size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries))
+ end, _ := usermem.Addr(addr + size).RoundUp()
+ regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+
+ addr = reflect.ValueOf(k.globalIDT).Pointer()
+ size = reflect.TypeOf(idt64{}).Size()
+ end, _ = usermem.Addr(addr + size).RoundUp()
+ regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end)
+
+ return regions
+}
+
// init initializes architecture-specific state.
-func (c *CPU) init() {
+func (c *CPU) init(cpuID int) {
+ c.kernelEntry = &c.kernel.cpuEntries[cpuID]
+ c.cpuSelf = c
// Null segment.
c.gdt[0].setNull()
@@ -65,6 +112,7 @@ func (c *CPU) init() {
// Set the kernel stack pointer in the TSS (virtual address).
stackAddr := c.StackTop()
+ c.stackTop = stackAddr
c.tss.rsp0Lo = uint32(stackAddr)
c.tss.rsp0Hi = uint32(stackAddr >> 32)
c.tss.ist1Lo = uint32(stackAddr)
@@ -183,7 +231,7 @@ func IsCanonical(addr uint64) bool {
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
- kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)
+ c.kernelCR3 = uintptr(c.kernel.PageTables.CR3(true, switchOpts.KernelPCID))
// Sanitize registers.
regs := switchOpts.Registers
@@ -197,15 +245,11 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
- jumpToKernel() // Switch to upper half.
- writeCR3(uintptr(userCR3)) // Change to user address space.
if switchOpts.FullRestore {
- vector = iret(c, regs)
+ vector = iret(c, regs, uintptr(userCR3))
} else {
- vector = sysret(c, regs)
+ vector = sysret(c, regs, uintptr(userCR3))
}
- writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
- jumpToUser() // Return to lower half.
SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
@@ -219,7 +263,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
//go:nosplit
func start(c *CPU) {
// Save per-cpu & FS segment.
- WriteGS(kernelAddr(c))
+ WriteGS(kernelAddr(c.kernelEntry))
WriteFS(uintptr(c.registers.Fs_base))
// Initialize floating point.
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index 14774c5db..b294ccc7c 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -25,13 +25,13 @@ func HaltAndResume()
func HaltEl1SvcAndResume()
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts) {
+func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
// Save the root page tables.
k.PageTables = opts.PageTables
}
// init initializes architecture-specific state.
-func (c *CPU) init() {
+func (c *CPU) init(cpuID int) {
// Set the kernel stack pointer(virtual address).
c.registers.Sp = uint64(c.StackTop())
@@ -64,11 +64,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
regs.Pstate |= UserFlagsSet
LoadFloatingPoint(switchOpts.FloatingPointState)
- SetTLS(regs.TPIDR_EL0)
kernelExitToEl0()
- regs.TPIDR_EL0 = GetTLS()
SaveFloatingPoint(switchOpts.FloatingPointState)
vector = c.vecCode
diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go
index ca968a036..0ec5c3bc5 100644
--- a/pkg/sentry/platform/ring0/lib_amd64.go
+++ b/pkg/sentry/platform/ring0/lib_amd64.go
@@ -61,21 +61,9 @@ func wrgsbase(addr uintptr)
// wrgsmsr writes to the GS_BASE MSR.
func wrgsmsr(addr uintptr)
-// writeCR3 writes the CR3 value.
-func writeCR3(phys uintptr)
-
-// readCR3 reads the current CR3 value.
-func readCR3() uintptr
-
// readCR2 reads the current CR2 value.
func readCR2() uintptr
-// jumpToKernel jumps to the kernel version of the current RIP.
-func jumpToKernel()
-
-// jumpToUser jumps to the user version of the current RIP.
-func jumpToUser()
-
// fninit initializes the floating point unit.
func fninit()
diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s
index 75d742750..2fe83568a 100644
--- a/pkg/sentry/platform/ring0/lib_amd64.s
+++ b/pkg/sentry/platform/ring0/lib_amd64.s
@@ -127,53 +127,6 @@ TEXT ·wrgsmsr(SB),NOSPLIT,$0-8
BYTE $0x0f; BYTE $0x30; // WRMSR
RET
-// jumpToUser changes execution to the user address.
-//
-// This works by changing the return value to the user version.
-TEXT ·jumpToUser(SB),NOSPLIT,$0
- MOVQ 0(SP), AX
- MOVQ ·KernelStartAddress(SB), BX
- NOTQ BX
- ANDQ BX, SP // Switch the stack.
- ANDQ BX, BP // Switch the frame pointer.
- ANDQ BX, AX // Future return value.
- MOVQ AX, 0(SP)
- RET
-
-// jumpToKernel changes execution to the kernel address space.
-//
-// This works by changing the return value to the kernel version.
-TEXT ·jumpToKernel(SB),NOSPLIT,$0
- MOVQ 0(SP), AX
- MOVQ ·KernelStartAddress(SB), BX
- ORQ BX, SP // Switch the stack.
- ORQ BX, BP // Switch the frame pointer.
- ORQ BX, AX // Future return value.
- MOVQ AX, 0(SP)
- RET
-
-// writeCR3 writes the given CR3 value.
-//
-// The code corresponds to:
-//
-// mov %rax, %cr3
-//
-TEXT ·writeCR3(SB),NOSPLIT,$0-8
- MOVQ cr3+0(FP), AX
- BYTE $0x0f; BYTE $0x22; BYTE $0xd8;
- RET
-
-// readCR3 reads the current CR3 value.
-//
-// The code corresponds to:
-//
-// mov %cr3, %rax
-//
-TEXT ·readCR3(SB),NOSPLIT,$0-8
- BYTE $0x0f; BYTE $0x20; BYTE $0xd8;
- MOVQ AX, ret+0(FP)
- RET
-
// readCR2 reads the current CR2 value.
//
// The code corresponds to:
diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go
index b8ab120a0..ca4075b09 100644
--- a/pkg/sentry/platform/ring0/offsets_amd64.go
+++ b/pkg/sentry/platform/ring0/offsets_amd64.go
@@ -30,14 +30,21 @@ func Emit(w io.Writer) {
c := &CPU{}
fmt.Fprintf(w, "\n// CPU offsets.\n")
- fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
- fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer())
+
+ e := &kernelEntry{}
+ fmt.Fprintf(w, "\n// CPU entry offsets.\n")
+ fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer())
+ fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer())
fmt.Fprintf(w, "\n// Bits.\n")
fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF)
+ fmt.Fprintf(w, "#define _RFLAGS_IOPL0 0x%02x\n", _RFLAGS_IOPL0)
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
fmt.Fprintf(w, "\n// Vectors.\n")
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 1d86b4bcf..45eba960d 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -125,4 +125,5 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_TLS 0x%02x\n", reflect.ValueOf(&p.TPIDR_EL0).Pointer()-reflect.ValueOf(p).Pointer())
}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
index 6409d1d91..520161755 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -78,7 +78,7 @@ const (
const (
executeDisable = xn
- optionMask = 0xfff | 0xfff<<48
+ optionMask = 0xfff | 0xffff<<48
protDefault = accessed | shared
)
@@ -188,7 +188,7 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) {
v |= mtNormal
} else {
v = v &^ user
- v |= mtDevicenGnRE // Strong order for the addresses with ring0.KernelStartAddress.
+ v |= mtNormal
}
atomic.StoreUintptr((*uintptr)(p), v)
}
diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go
index 9da0ea685..34fbc1c35 100644
--- a/pkg/sentry/platform/ring0/x86.go
+++ b/pkg/sentry/platform/ring0/x86.go
@@ -39,7 +39,9 @@ const (
_RFLAGS_AC = 1 << 18
_RFLAGS_NT = 1 << 14
- _RFLAGS_IOPL = 3 << 12
+ _RFLAGS_IOPL0 = 1 << 12
+ _RFLAGS_IOPL1 = 1 << 13
+ _RFLAGS_IOPL = _RFLAGS_IOPL0 | _RFLAGS_IOPL1
_RFLAGS_DF = 1 << 10
_RFLAGS_IF = 1 << 9
_RFLAGS_STEP = 1 << 8
@@ -67,15 +69,45 @@ const (
KernelFlagsSet = _RFLAGS_RESERVED
// UserFlagsSet are always set in userspace.
- UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF
+ //
+ // _RFLAGS_IOPL is a set of two bits and it shows the I/O privilege
+ // level. The Current Privilege Level (CPL) of the task must be less
+ // than or equal to the IOPL in order for the task or program to access
+ // I/O ports.
+ //
+ // Here, _RFLAGS_IOPL0 is used only to determine whether the task is
+ // running in the kernel or userspace mode. In the user mode, the CPL is
+ // always 3 and it doesn't matter what IOPL is set if it is bellow CPL.
+ //
+ // We need to have one bit which will be always different in user and
+ // kernel modes. And we have to remember that even though we have
+ // KernelFlagsClear, we still can see some of these flags in the kernel
+ // mode. This can happen when the goruntime switches on a goroutine
+ // which has been saved in the host mode. On restore, the popf
+ // instruction is used to restore flags and this means that all flags
+ // what the goroutine has in the host mode will be restored in the
+ // kernel mode.
+ //
+ // _RFLAGS_IOPL0 is never set in host and kernel modes and we always set
+ // it in the user mode. So if this flag is set, the task is running in
+ // the user mode and if it isn't set, the task is running in the kernel
+ // mode.
+ UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF | _RFLAGS_IOPL0
// KernelFlagsClear should always be clear in the kernel.
KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT
// UserFlagsClear are always cleared in userspace.
- UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL
+ UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL1
)
+// IsKernelFlags returns true if rflags coresponds to the kernel mode.
+//
+// go:nosplit
+func IsKernelFlags(rflags uint64) bool {
+ return rflags&_RFLAGS_IOPL0 == 0
+}
+
// Vector is an exception vector.
type Vector uintptr
@@ -104,7 +136,7 @@ const (
VirtualizationException
SecurityException = 0x1e
SyscallInt80 = 0x80
- _NR_INTERRUPTS = SyscallInt80 + 1
+ _NR_INTERRUPTS = 0x100
)
// System call vectors.
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 94cb80437..904a12e38 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -147,10 +147,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
case stack.FilterTable:
table = stack.EmptyFilterTable()
case stack.NATTable:
- if ipv6 {
- nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)")
- return syserr.ErrInvalidArgument
- }
table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 19b18b2d6..0e14447fe 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -47,6 +47,9 @@ func init() {
registerTargetMaker(&redirectTargetMaker{
NetworkProtocol: header.IPv4ProtocolNumber,
})
+ registerTargetMaker(&nfNATTargetMaker{
+ NetworkProtocol: header.IPv6ProtocolNumber,
+ })
}
type standardTargetMaker struct {
@@ -250,6 +253,86 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return &target, nil
}
+type nfNATTarget struct {
+ Target linux.XTEntryTarget
+ Range linux.NFNATRange
+}
+
+const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
+
+type nfNATTargetMaker struct {
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (rm *nfNATTargetMaker) id() stack.TargetID {
+ return stack.TargetID{
+ Name: stack.RedirectTargetName,
+ NetworkProtocol: rm.NetworkProtocol,
+ }
+}
+
+func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
+ rt := target.(*stack.RedirectTarget)
+ nt := nfNATTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: nfNATMarhsalledSize,
+ },
+ Range: linux.NFNATRange{
+ Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
+ },
+ }
+ copy(nt.Target.Name[:], stack.RedirectTargetName)
+ copy(nt.Range.MinAddr[:], rt.Addr)
+ copy(nt.Range.MaxAddr[:], rt.Addr)
+
+ nt.Range.MinProto = htons(rt.Port)
+ nt.Range.MaxProto = nt.Range.MinProto
+
+ ret := make([]byte, 0, nfNATMarhsalledSize)
+ return binary.Marshal(ret, usermem.ByteOrder, nt)
+}
+
+func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+ if size := nfNATMarhsalledSize; len(buf) < size {
+ nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ nflog("nfNATTargetMaker: bad proto %d", p)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var natRange linux.NFNATRange
+ buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize]
+ binary.Unmarshal(buf, usermem.ByteOrder, &natRange)
+
+ // We don't support port or address ranges.
+ if natRange.MinAddr != natRange.MaxAddr {
+ nflog("nfNATTargetMaker: MinAddr and MaxAddr are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+ if natRange.MinProto != natRange.MaxProto {
+ nflog("nfNATTargetMaker: MinProto and MaxProto are different")
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/3549): Check for other flags.
+ // For now, redirect target only supports destination change.
+ if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+ nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ target := stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Addr: tcpip.Address(natRange.MinAddr[:]),
+ Port: ntohs(natRange.MinProto),
+ }
+
+ return &target, nil
+}
+
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) {
@@ -306,7 +389,7 @@ func (jt *JumpTarget) ID() stack.TargetID {
}
// Action implements stack.Target.Action.
-func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 335822c0e..87e30d742 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1512,8 +1512,17 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return &vP, nil
case linux.IP6T_ORIGINAL_DST:
- // TODO(gvisor.dev/issue/170): ip6tables.
- return nil, syserr.ErrInvalidArgument
+ if outLen < int(binary.Size(linux.SockAddrInet6{})) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OriginalDestinationOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
if outLen < linux.SizeOfIPTGetinfo {
@@ -1555,6 +1564,26 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
return &entries, nil
+ case linux.IP6T_SO_GET_REVISION_TARGET:
+ if outLen < linux.SizeOfXTGetRevision {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return &ret, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 75752b2e6..a2e441448 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -21,6 +21,7 @@ go_library(
"sys_identity.go",
"sys_inotify.go",
"sys_lseek.go",
+ "sys_membarrier.go",
"sys_mempolicy.go",
"sys_mmap.go",
"sys_mount.go",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 5f26697d2..9c9def7cd 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -376,7 +376,7 @@ var AMD64 = &kernel.SyscallTable{
321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
322: syscalls.Supported("execveat", Execveat),
323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 324: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil),
325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
// Syscalls implemented after 325 are "backports" from versions
@@ -527,8 +527,8 @@ var ARM64 = &kernel.SyscallTable{
96: syscalls.Supported("set_tid_address", SetTidAddress),
97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
- 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 99: syscalls.Supported("set_robust_list", SetRobustList),
+ 100: syscalls.Supported("get_robust_list", GetRobustList),
101: syscalls.Supported("nanosleep", Nanosleep),
102: syscalls.Supported("getitimer", Getitimer),
103: syscalls.Supported("setitimer", Setitimer),
@@ -695,7 +695,7 @@ var ARM64 = &kernel.SyscallTable{
280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
281: syscalls.Supported("execveat", Execveat),
282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 283: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil),
284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
// Syscalls after 284 are "backports" from versions of Linux after 4.4.
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 98331eb3c..519066a47 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -84,6 +84,7 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro
}
rel = f.Dirent
if !fs.IsDir(rel.Inode.StableAttr) {
+ f.DecRef(t)
return syserror.ENOTDIR
}
}
diff --git a/pkg/sentry/syscalls/linux/sys_membarrier.go b/pkg/sentry/syscalls/linux/sys_membarrier.go
new file mode 100644
index 000000000..63ee5d435
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_membarrier.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Membarrier implements syscall membarrier(2).
+func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ cmd := args[0].Int()
+ flags := args[1].Uint()
+
+ switch cmd {
+ case linux.MEMBARRIER_CMD_QUERY:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ var supportedCommands uintptr
+ if t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ supportedCommands |= linux.MEMBARRIER_CMD_GLOBAL |
+ linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED |
+ linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED |
+ linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED |
+ linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED
+ }
+ if t.RSeqAvailable() {
+ supportedCommands |= linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ |
+ linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ
+ }
+ return supportedCommands, nil, nil
+ case linux.MEMBARRIER_CMD_GLOBAL, linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ if cmd == linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED && !t.MemoryManager().IsMembarrierPrivateEnabled() {
+ return 0, nil, syserror.EPERM
+ }
+ return 0, nil, t.Kernel().Platform.GlobalMemoryBarrier()
+ case linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ // no-op
+ return 0, nil, nil
+ case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.Kernel().Platform.HaveGlobalMemoryBarrier() {
+ return 0, nil, syserror.EINVAL
+ }
+ t.MemoryManager().EnableMembarrierPrivate()
+ return 0, nil, nil
+ case linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ:
+ if flags&^linux.MEMBARRIER_CMD_FLAG_CPU != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.RSeqAvailable() {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.MemoryManager().IsMembarrierRSeqEnabled() {
+ return 0, nil, syserror.EPERM
+ }
+ // MEMBARRIER_CMD_FLAG_CPU and cpu_id are ignored since we don't have
+ // the ability to preempt specific CPUs.
+ return 0, nil, t.Kernel().Platform.PreemptAllCPUs()
+ case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ:
+ if flags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if !t.RSeqAvailable() {
+ return 0, nil, syserror.EINVAL
+ }
+ t.MemoryManager().EnableMembarrierRSeq()
+ return 0, nil, nil
+ default:
+ // Probably a command we don't implement.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
index 674d341b6..6320593f0 100644
--- a/pkg/sentry/syscalls/linux/sys_sysinfo.go
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -26,8 +26,12 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
addr := args[0].Pointer()
mf := t.Kernel().MemoryFile()
- mf.UpdateUsage()
- _, totalUsage := usage.MemoryAccounting.Copy()
+ mfUsage, err := mf.TotalUsage()
+ if err != nil {
+ return 0, nil, err
+ }
+ memStats, _ := usage.MemoryAccounting.Copy()
+ totalUsage := mfUsage + memStats.Mapped
totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
memFree := totalSize - totalUsage
if memFree > totalSize {
@@ -37,12 +41,12 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
// Only a subset of the fields in sysinfo_t make sense to return.
si := linux.Sysinfo{
- Procs: uint16(len(t.PIDNamespace().Tasks())),
+ Procs: uint16(t.Kernel().TaskSet().Root.NumTasks()),
Uptime: t.Kernel().MonotonicClock().Now().Seconds(),
TotalRAM: totalSize,
FreeRAM: memFree,
Unit: 1,
}
- _, err := si.CopyOut(t, addr)
+ _, err = si.CopyOut(t, addr)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index 0df3bd449..c50fd97eb 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -163,6 +163,7 @@ func Override() {
// Override ARM64.
s = linux.ARM64
+ s.Table[2] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"})
s.Table[5] = syscalls.Supported("setxattr", SetXattr)
s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr)
s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr)
@@ -200,6 +201,7 @@ func Override() {
s.Table[44] = syscalls.Supported("fstatfs", Fstatfs)
s.Table[45] = syscalls.Supported("truncate", Truncate)
s.Table[46] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[47] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil)
s.Table[48] = syscalls.Supported("faccessat", Faccessat)
s.Table[49] = syscalls.Supported("chdir", Chdir)
s.Table[50] = syscalls.Supported("fchdir", Fchdir)
@@ -221,12 +223,14 @@ func Override() {
s.Table[68] = syscalls.Supported("pwrite64", Pwrite64)
s.Table[69] = syscalls.Supported("preadv", Preadv)
s.Table[70] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[71] = syscalls.Supported("sendfile", Sendfile)
s.Table[72] = syscalls.Supported("pselect", Pselect)
s.Table[73] = syscalls.Supported("ppoll", Ppoll)
s.Table[74] = syscalls.Supported("signalfd4", Signalfd4)
s.Table[76] = syscalls.Supported("splice", Splice)
s.Table[77] = syscalls.Supported("tee", Tee)
s.Table[78] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[79] = syscalls.Supported("newfstatat", Newfstatat)
s.Table[80] = syscalls.Supported("fstat", Fstat)
s.Table[81] = syscalls.Supported("sync", Sync)
s.Table[82] = syscalls.Supported("fsync", Fsync)
@@ -251,8 +255,10 @@ func Override() {
s.Table[210] = syscalls.Supported("shutdown", Shutdown)
s.Table[211] = syscalls.Supported("sendmsg", SendMsg)
s.Table[212] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[213] = syscalls.Supported("readahead", Readahead)
s.Table[221] = syscalls.Supported("execve", Execve)
s.Table[222] = syscalls.Supported("mmap", Mmap)
+ s.Table[223] = syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil)
s.Table[242] = syscalls.Supported("accept4", Accept4)
s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg)
s.Table[267] = syscalls.Supported("syncfs", Syncfs)
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 8093ca55c..c855608db 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -92,7 +92,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index dfc3ae6c0..79a2d8c41 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -46,8 +46,9 @@ import (
// +stateify savable
type Mount struct {
// vfs, fs, root are immutable. References are held on fs and root.
+ // Note that for a disconnected mount, root may be nil.
//
- // Invariant: root belongs to fs.
+ // Invariant: if not nil, root belongs to fs.
vfs *VirtualFilesystem
fs *Filesystem
root *Dentry
@@ -498,7 +499,9 @@ func (mnt *Mount) DecRef(ctx context.Context) {
mnt.vfs.mounts.seq.EndWrite()
mnt.vfs.mountMu.Unlock()
}
- mnt.root.DecRef(ctx)
+ if mnt.root != nil {
+ mnt.root.DecRef(ctx)
+ }
mnt.fs.DecRef(ctx)
if vd.Ok() {
vd.DecRef(ctx)
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 5bd756ea5..31ea3139c 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -122,6 +122,13 @@ type VirtualFilesystem struct {
filesystems map[*Filesystem]struct{}
}
+// Release drops references on filesystem objects held by vfs.
+//
+// Precondition: This must be called after VFS.Init() has succeeded.
+func (vfs *VirtualFilesystem) Release(ctx context.Context) {
+ vfs.anonMount.DecRef(ctx)
+}
+
// Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes.
func (vfs *VirtualFilesystem) Init(ctx context.Context) error {
if vfs.mountpoints != nil {
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index ea0c5413d..8db70a700 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -84,8 +84,8 @@ type VectorisedView struct {
size int
}
-// NewVectorisedView creates a new vectorised view from an already-allocated slice
-// of View and sets its size.
+// NewVectorisedView creates a new vectorised view from an already-allocated
+// slice of View and sets its size.
func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
@@ -170,8 +170,9 @@ func (vv *VectorisedView) CapLength(length int) {
}
// Clone returns a clone of this VectorisedView.
-// If the buffer argument is large enough to contain all the Views of this VectorisedView,
-// the method will avoid allocations and use the buffer to store the Views of the clone.
+// If the buffer argument is large enough to contain all the Views of this
+// VectorisedView, the method will avoid allocations and use the buffer to
+// store the Views of the clone.
func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
@@ -209,7 +210,8 @@ func (vv *VectorisedView) PullUp(count int) (View, bool) {
return newFirst, true
}
-// Size returns the size in bytes of the entire content stored in the vectorised view.
+// Size returns the size in bytes of the entire content stored in the
+// vectorised view.
func (vv *VectorisedView) Size() int {
return vv.size
}
@@ -222,6 +224,12 @@ func (vv *VectorisedView) ToView() View {
if len(vv.views) == 1 {
return vv.views[0]
}
+ return vv.ToOwnedView()
+}
+
+// ToOwnedView returns a single view containing the content of the vectorised
+// view that vv does not own.
+func (vv *VectorisedView) ToOwnedView() View {
u := make([]byte, 0, vv.size)
for _, v := range vv.views {
u = append(u, v...)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 19627fa9b..d4d785cca 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -118,18 +118,82 @@ func TTL(ttl uint8) NetworkChecker {
v = ip.HopLimit()
}
if v != ttl {
- t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl)
+ }
+ }
+}
+
+// IPFullLength creates a checker for the full IP packet length. The
+// expected size is checked against both the Total Length in the
+// header and the number of bytes received.
+func IPFullLength(packetLength uint16) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ var v uint16
+ var l uint16
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TotalLength()
+ l = uint16(len(ip))
+ case header.IPv6:
+ v = ip.PayloadLength() + header.IPv6FixedHeaderSize
+ l = uint16(len(ip))
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip)
+ }
+ if l != packetLength {
+ t.Errorf("bad packet length, got = %d, want = %d", l, packetLength)
+ }
+ if v != packetLength {
+ t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength)
+ }
+ }
+}
+
+// IPv4HeaderLength creates a checker that checks the IPv4 Header length.
+func IPv4HeaderLength(headerLength int) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if hl := ip.HeaderLength(); hl != uint8(headerLength) {
+ t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength)
+ }
+ default:
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip)
}
}
}
// PayloadLen creates a checker that checks the payload length.
-func PayloadLen(plen int) NetworkChecker {
+func PayloadLen(payloadLength int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- if l := len(h[0].Payload()); l != plen {
- t.Errorf("Bad payload length, got %v, want %v", l, plen)
+ if l := len(h[0].Payload()); l != payloadLength {
+ t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength)
+ }
+ }
+}
+
+// IPv4Options returns a checker that checks the options in an IPv4 packet.
+func IPv4Options(want []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ ip, ok := h[0].(header.IPv4)
+ if !ok {
+ t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
+ }
+ options := ip.Options()
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(want) == 0 && len(options) == 0 {
+ return
+ }
+ if diff := cmp.Diff(want, options); diff != "" {
+ t.Errorf("options mismatch (-want +got):\n%s", diff)
}
}
}
@@ -139,11 +203,11 @@ func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- // We only do this of IPv4 for now.
+ // We only do this for IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
- t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset)
}
}
}
@@ -154,11 +218,11 @@ func FragmentFlags(flags uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
- // We only do this of IPv4 for now.
+ // We only do this for IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
- t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
+ t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags)
}
}
}
@@ -208,7 +272,7 @@ func TOS(tos uint8, label uint32) NetworkChecker {
t.Helper()
if v, l := h[0].TOS(); v != tos || l != label {
- t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label)
}
}
}
@@ -234,7 +298,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
t.Helper()
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
@@ -261,7 +325,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber)
}
// Verify the checksum.
@@ -297,7 +361,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker {
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
- t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
@@ -316,7 +380,7 @@ func SrcPort(port uint16) TransportChecker {
t.Helper()
if p := h.SourcePort(); p != port {
- t.Errorf("Bad source port, got %v, want %v", p, port)
+ t.Errorf("Bad source port, got = %d, want = %d", p, port)
}
}
}
@@ -327,7 +391,7 @@ func DstPort(port uint16) TransportChecker {
t.Helper()
if p := h.DestinationPort(); p != port {
- t.Errorf("Bad destination port, got %v, want %v", p, port)
+ t.Errorf("Bad destination port, got = %d, want = %d", p, port)
}
}
}
@@ -359,7 +423,7 @@ func TCPSeqNum(seq uint32) TransportChecker {
}
if s := tcp.SequenceNumber(); s != seq {
- t.Errorf("Bad sequence number, got %v, want %v", s, seq)
+ t.Errorf("Bad sequence number, got = %d, want = %d", s, seq)
}
}
}
@@ -375,7 +439,7 @@ func TCPAckNum(seq uint32) TransportChecker {
}
if s := tcp.AckNumber(); s != seq {
- t.Errorf("Bad ack number, got %v, want %v", s, seq)
+ t.Errorf("Bad ack number, got = %d, want = %d", s, seq)
}
}
}
@@ -492,7 +556,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
- t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS)
}
foundMSS = true
i += 4
@@ -502,7 +566,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
}
v := int(opts[i+2])
if v != wantOpts.WS {
- t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS)
}
foundWS = true
i += 3
@@ -551,7 +615,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
- t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr)
}
if wantOpts.SACKPermitted && !foundSACKPermitted {
t.Errorf("SACKPermitted option not found. Options: %x", opts)
@@ -589,7 +653,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
- t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
@@ -609,19 +673,19 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp
}
if wantTS != foundTS {
- t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
- t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
- t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr)
}
}
}
-// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
-// contain any SACK blocks in the TCP options.
+// TCPNoSACKBlockChecker creates a checker that verifies that the segment does
+// not contain any SACK blocks in the TCP options.
func TCPNoSACKBlockChecker() TransportChecker {
return TCPSACKBlockChecker(nil)
}
@@ -679,7 +743,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
}
if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
- t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks)
}
}
}
@@ -695,8 +759,8 @@ func Payload(want []byte) TransportChecker {
}
}
-// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
-// potentially additional ICMPv4 header fields.
+// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4
+// and potentially additional ICMPv4 header fields.
func ICMPv4(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
@@ -724,10 +788,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -739,10 +803,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
icmpv4, ok := h.(header.ICMPv4)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
}
if got := icmpv4.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident.
+func ICMPv4Ident(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Ident(); got != want {
+ t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence.
+func ICMPv4Seq(want uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Sequence(); got != want {
+ t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
+// This assumes that the payload exactly makes up the rest of the slice.
+func ICMPv4Checksum() TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ heldChecksum := icmpv4.Checksum()
+ icmpv4.SetChecksum(0)
+ newChecksum := ^header.Checksum(icmpv4, 0)
+ icmpv4.SetChecksum(heldChecksum)
+ if heldChecksum != newChecksum {
+ t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum)
+ }
+ }
+}
+
+// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet.
+func ICMPv4Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ payload := icmpv4.Payload()
+ if diff := cmp.Diff(want, payload); diff != "" {
+ t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
}
}
}
@@ -782,10 +912,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Type(); got != want {
- t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want)
}
}
}
@@ -797,10 +927,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
icmpv6, ok := h.(header.ICMPv6)
if !ok {
- t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
}
if got := icmpv6.Code(); got != want {
- t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want)
}
}
}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index eaface8cb..95ade0e5c 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -117,25 +117,31 @@ func (b Ethernet) Encode(e *EthernetFields) {
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}
-// IsValidUnicastEthernetAddress returns true if addr is a valid unicast
+// IsMulticastEthernetAddress returns true if the address is a multicast
+// ethernet address.
+func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool {
+ if len(addr) != EthernetAddressSize {
+ return false
+ }
+
+ return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0
+}
+
+// IsValidUnicastEthernetAddress returns true if the address is a unicast
// ethernet address.
func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
- // Must be of the right length.
if len(addr) != EthernetAddressSize {
return false
}
- // Must not be unspecified.
if addr == unspecifiedEthernetAddress {
return false
}
- // Must not be a multicast.
if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
return false
}
- // addr is a valid unicast ethernet address.
return true
}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 14413f2ce..3bc8b2b21 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -67,6 +67,53 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
}
}
+func TestIsMulticastEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.LinkAddress
+ expected bool
+ }{
+ {
+ "Nil",
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty",
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "InvalidLength",
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Unspecified",
+ unspecifiedEthernetAddress,
+ false,
+ },
+ {
+ "Multicast",
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ true,
+ },
+ {
+ "Unicast",
+ tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := IsMulticastEthernetAddress(test.addr); got != test.expected {
+ t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected)
+ }
+ })
+ }
+}
+
func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index c00bcadfb..504408878 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -126,15 +126,6 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
-// SetPointer sets the pointer field in a Parameter error packet.
-// This is the first byte of the type specific data field.
-func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
-
-// SetTypeSpecific sets the full 32 bit type specific data field.
-func (b ICMPv4) SetTypeSpecific(val uint32) {
- binary.BigEndian.PutUint32(b[icmpv4PointerOffset:], val)
-}
-
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 4eb5abd79..6be31beeb 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -156,9 +156,14 @@ const (
// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
const (
+ // ICMPv6ErroneousHeader indicates an erroneous header field was encountered.
ICMPv6ErroneousHeader ICMPv6Code = 0
- ICMPv6UnknownHeader ICMPv6Code = 1
- ICMPv6UnknownOption ICMPv6Code = 2
+
+ // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered.
+ ICMPv6UnknownHeader ICMPv6Code = 1
+
+ // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered.
+ ICMPv6UnknownOption ICMPv6Code = 2
)
// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
@@ -177,7 +182,12 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
-// SetTypeSpecific sets the full 32 bit type specific data field.
+// TypeSpecific returns the type specific data field.
+func (b ICMPv6) TypeSpecific() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
+}
+
+// SetTypeSpecific sets the type specific data field.
func (b ICMPv6) SetTypeSpecific(val uint32) {
binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index b07d9991d..4c6e4be64 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -16,10 +16,29 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// RFC 971 defines the fields of the IPv4 header on page 11 using the following
+// diagram: ("Figure 4")
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |Version| IHL |Type of Service| Total Length |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Identification |Flags| Fragment Offset |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Time to Live | Protocol | Header Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Source Address |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Destination Address |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Options | Padding |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//
const (
versIHL = 0
tos = 1
@@ -33,6 +52,7 @@ const (
checksum = 10
srcAddr = 12
dstAddr = 16
+ options = 20
)
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
@@ -76,7 +96,8 @@ type IPv4Fields struct {
// IPv4 represents an ipv4 header stored in a byte array.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
-// Always call IsValid() to validate an instance of IPv4 before using other methods.
+// Always call IsValid() to validate an instance of IPv4 before using other
+// methods.
type IPv4 []byte
const (
@@ -151,13 +172,44 @@ func IPVersion(b []byte) int {
if len(b) < versIHL+1 {
return -1
}
- return int(b[versIHL] >> 4)
+ return int(b[versIHL] >> ipVersionShift)
}
+// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits
+// of the first byte, and is counted in multiples of 4 bytes.
+//
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |Version| IHL |Type of Service| Total Length |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// (...)
+// Version: 4 bits
+// The Version field indicates the format of the internet header. This
+// document describes version 4.
+//
+// IHL: 4 bits
+// Internet Header Length is the length of the internet header in 32
+// bit words, and thus points to the beginning of the data. Note that
+// the minimum value for a correct header is 5.
+//
+const (
+ ipVersionShift = 4
+ ipIHLMask = 0x0f
+ IPv4IHLStride = 4
+)
+
// HeaderLength returns the value of the "header length" field of the ipv4
// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
- return (b[versIHL] & 0xf) * 4
+ return (b[versIHL] & ipIHLMask) * IPv4IHLStride
+}
+
+// SetHeaderLength sets the value of the "Internet Header Length" field.
+func (b IPv4) SetHeaderLength(hdrLen uint8) {
+ if hdrLen > IPv4MaximumHeaderSize {
+ panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize))
+ }
+ b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask)
}
// ID returns the value of the identifier field of the ipv4 header.
@@ -211,6 +263,12 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// Options returns a a buffer holding the options.
+func (b IPv4) Options() []byte {
+ hdrLen := b.HeaderLength()
+ return b[options:hdrLen:hdrLen]
+}
+
// TransportProtocol implements Network.TransportProtocol.
func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.Protocol())
@@ -236,6 +294,11 @@ func (b IPv4) SetTOS(v uint8, _ uint32) {
b[tos] = v
}
+// SetTTL sets the "Time to Live" field of the IPv4 header.
+func (b IPv4) SetTTL(v byte) {
+ b[ttl] = v
+}
+
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
@@ -276,7 +339,7 @@ func (b IPv4) CalculateChecksum() uint16 {
// Encode encodes all the fields of the ipv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
- b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b.SetHeaderLength(i.IHL)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 0761a1807..c5d8a3456 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -34,6 +34,9 @@ const (
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
+
+ // IPv6FixedHeaderSize is the size of the fixed header.
+ IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -69,7 +72,7 @@ type IPv6 []byte
const (
// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
- IPv6MinimumSize = 40
+ IPv6MinimumSize = IPv6FixedHeaderSize
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
IPv6AddressSize = 16
@@ -306,14 +309,21 @@ func IsV6UnicastAddress(addr tcpip.Address) bool {
return addr[0] != 0xff
}
+const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
+
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
- const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
}
+// IsSolicitedNodeAddr determines whether the address is a solicited-node
+// multicast address.
+func IsSolicitedNodeAddr(addr tcpip.Address) bool {
+ return solicitedNodeMulticastPrefix == addr[:len(addr)-3]
+}
+
// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
//
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
index 3499d8399..583c2c5d3 100644
--- a/pkg/tcpip/header/ipv6_extension_headers.go
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
// obtained before modification is no longer used.
type IPv6OptionsExtHdrOptionsIterator struct {
reader bytes.Reader
+
+ // optionOffset is the number of bytes from the first byte of the
+ // options field to the beginning of the current option.
+ optionOffset uint32
+
+ // nextOptionOffset is the offset of the next option.
+ nextOptionOffset uint32
+}
+
+// OptionOffset returns the number of bytes parsed while processing the
+// option field of the current Extension Header.
+func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
+ return i.optionOffset
}
// IPv6OptionUnknownAction is the action that must be taken if the processing
@@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
// the options data, or an error occured.
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
for {
+ i.optionOffset = i.nextOptionOffset
temp, err := i.reader.ReadByte()
if err != nil {
// If we can't read the first byte of a new option, then we know the
@@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
// know the option does not have Length and Data fields. End processing of
// the Pad1 option and continue processing the buffer as a new option.
if id == ipv6Pad1ExtHdrOptionIdentifier {
+ i.nextOptionOffset = i.optionOffset + 1
continue
}
@@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error
return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
}
- // Special-case the variable length padding option to avoid a copy.
- if id == ipv6PadNExtHdrOptionIdentifier {
- // Do we have enough bytes in the reader for the PadN option?
- if n := i.reader.Len(); n < int(length) {
- // Reset the reader to effectively consume the remaining buffer.
- i.reader.Reset(nil)
-
- // We return the same error as if we failed to read a non-padding option
- // so consumers of this iterator don't need to differentiate between
- // padding and non-padding options.
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
- }
+ // Do we have enough bytes in the reader for the next option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */
+ switch id {
+ case ipv6PadNExtHdrOptionIdentifier:
+ // Special-case the variable length padding option to avoid a copy.
if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
}
-
- // End processing of the PadN option and continue processing the buffer as
- // a new option.
continue
- }
-
- bytes := make([]byte, length)
- if n, err := io.ReadFull(&i.reader, bytes); err != nil {
- // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
- // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
- // Length field found in the option.
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
+ default:
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
}
-
- return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
-
- return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
}
}
@@ -382,6 +396,29 @@ type IPv6PayloadIterator struct {
// Indicates to the iterator that it should return the remaining payload as a
// raw payload on the next call to Next.
forceRaw bool
+
+ // headerOffset is the offset of the beginning of the current extension
+ // header starting from the beginning of the fixed header.
+ headerOffset uint32
+
+ // parseOffset is the byte offset into the current extension header of the
+ // field we are currently examining. It can be added to the header offset
+ // if the absolute offset within the packet is required.
+ parseOffset uint32
+
+ // nextOffset is the offset of the next header.
+ nextOffset uint32
+}
+
+// HeaderOffset returns the offset to the start of the extension
+// header most recently processed.
+func (i IPv6PayloadIterator) HeaderOffset() uint32 {
+ return i.headerOffset
+}
+
+// ParseOffset returns the number of bytes successfully parsed.
+func (i IPv6PayloadIterator) ParseOffset() uint32 {
+ return i.headerOffset + i.parseOffset
}
// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
@@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa
nextHdrIdentifier: nextHdrIdentifier,
payload: payload.Clone(nil),
// We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
- reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ nextOffset: IPv6FixedHeaderSize,
}
}
@@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
// Next is unable to return anything because the iterator has reached the end of
// the payload, or an error occured.
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ i.headerOffset = i.nextOffset
+ i.parseOffset = 0
// We could be forced to return i as a raw header when the previous header was
// a fragment extension header as the data following the fragment extension
// header may not be complete.
@@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
return IPv6RoutingExtHdr(bytes), false, nil
case IPv6FragmentExtHdrIdentifier:
var data [6]byte
- // We ignore the returned bytes becauase we know the fragment extension
+ // We ignore the returned bytes because we know the fragment extension
// header specific data will fit in data.
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
if err != nil {
@@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
if err != nil {
return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
}
+ i.parseOffset++
var length uint8
length, err = i.reader.ReadByte()
i.payload.TrimFront(1)
+
if err != nil {
if fragmentHdr {
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
@@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP
length = 0
}
+ // Make parseOffset point to the first byte of the Extension Header
+ // specific data.
+ i.parseOffset++
+
+ // length is in 8 byte chunks but doesn't include the first one.
+ // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
+ // in section 4.8 for new extension headers at the top of page 24.
+ // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
+ // units, not including the first 8 octets.
+ i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
+
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
if bytes == nil {
bytes = make([]byte, bytesLen)
diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go
index b5540bf66..17a49d4fa 100644
--- a/pkg/tcpip/header/ipversion_test.go
+++ b/pkg/tcpip/header/ipversion_test.go
@@ -22,7 +22,7 @@ import (
func TestIPv4(t *testing.T) {
b := header.IPv4(make([]byte, header.IPv4MinimumSize))
- b.Encode(&header.IPv4Fields{})
+ b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize})
const want = header.IPv4Version
if v := header.IPVersion(b); v != want {
diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD
new file mode 100644
index 000000000..9f31c1ffc
--- /dev/null
+++ b/pkg/tcpip/link/pipe/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pipe",
+ srcs = ["pipe.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
new file mode 100644
index 000000000..76f563811
--- /dev/null
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe provides the implementation of pipe-like data-link layer
+// endpoints. Such endpoints allow packets to be sent between two interfaces.
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns both ends of a new pipe.
+func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) {
+ ep1 := &Endpoint{
+ linkAddr: linkAddr1,
+ capabilities: capabilities,
+ }
+ ep2 := &Endpoint{
+ linkAddr: linkAddr2,
+ linked: ep1,
+ capabilities: capabilities,
+ }
+ ep1.linked = ep2
+ return ep1, ep2
+}
+
+// Endpoint is one end of a pipe.
+type Endpoint struct {
+ capabilities stack.LinkEndpointCapabilities
+ linkAddr tcpip.LinkAddress
+ dispatcher stack.NetworkDispatcher
+ linked *Endpoint
+ onWritePacket func(*stack.PacketBuffer)
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if !e.linked.IsAttached() {
+ return nil
+ }
+
+ // The pipe endpoint will accept all multicast/broadcast link traffic and only
+ // unicast traffic destined to itself.
+ if len(e.linked.linkAddr) != 0 &&
+ r.RemoteLinkAddress != e.linked.linkAddr &&
+ r.RemoteLinkAddress != header.EthernetBroadcastAddress &&
+ !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) {
+ return nil
+ }
+
+ e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ panic("not implemented")
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// Wait implements stack.LinkEndpoint.
+func (*Endpoint) Wait() {}
+
+// MTU implements stack.LinkEndpoint.
+func (*Endpoint) MTU() uint32 {
+ return header.IPv6MinimumMTU
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.capabilities
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (*Endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index b6ddbe81e..f94491026 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -76,13 +76,29 @@ func (d *Device) Release(ctx context.Context) {
}
}
+// NICID returns the NIC ID of the device.
+//
+// Must only be called after the device has been attached to an endpoint.
+func (d *Device) NICID() tcpip.NICID {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+
+ if d.endpoint == nil {
+ panic("called NICID on a device that has not been attached")
+ }
+
+ return d.endpoint.nicID
+}
+
// SetIff services TUNSETIFF ioctl(2) request.
-func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+//
+// Returns true if a new NIC was created; false if an existing one was attached.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) {
d.mu.Lock()
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
// Input validations.
@@ -90,7 +106,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
isTap := flags&linux.IFF_TAP != 0
supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
prefix := "tun"
@@ -103,32 +119,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
linkCaps |= stack.CapabilityResolutionRequired
}
- endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
d.endpoint = endpoint
d.notifyHandle = d.endpoint.AddNotify(d)
d.flags = flags
- return nil
+ return created, nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
- if nic, found := s.GetNICByName(name); found {
- endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if linkEP := s.GetLinkEndpointByName(name); linkEP != nil {
+ endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, syserror.EOPNOTSUPP
+ return nil, false, syserror.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
continue
}
- return endpoint, nil
+ return endpoint, false, nil
}
}
@@ -151,12 +167,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
})
switch err {
case nil:
- return endpoint, nil
+ return endpoint, true, nil
case tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, false, syserror.EINVAL
}
}
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index b47a7be51..7df77c66e 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -49,7 +49,6 @@ type endpoint struct {
enabled uint32
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
}
@@ -92,12 +91,12 @@ func (e *endpoint) DefaultTTL() uint8 {
}
func (e *endpoint) MTU() uint32 {
- lmtu := e.linkEP.MTU()
+ lmtu := e.nic.MTU()
return lmtu - uint32(e.MaxHeaderLength())
}
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.ARPSize
+ return e.nic.MaxHeaderLength() + header.ARPSize
}
func (e *endpoint) Close() {
@@ -154,17 +153,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ origSender := h.HardwareAddressSender()
+ r.RemoteLinkAddress = tcpip.LinkAddress(origSender)
+ respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
- packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.HardwareAddressTarget(), origSender)
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt)
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
@@ -207,7 +214,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
e := &endpoint{
protocol: p,
nic: nic,
- linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
}
@@ -223,6 +229,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
+ NetProto: ProtocolNumber,
RemoteLinkAddress: remoteLinkAddr,
}
if len(r.RemoteLinkAddress) == 0 {
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index e247f06a4..47fb63290 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -29,6 +29,8 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
],
)
@@ -44,5 +46,7 @@ go_test(
deps = [
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
+ "//pkg/tcpip/network/testutil",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index e1909fab0..ed502a473 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -13,7 +13,7 @@
// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
-// It is based on RFC 791 and RFC 815.
+// It is based on RFC 791, RFC 815 and RFC 8200.
package fragmentation
import (
@@ -25,12 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
- // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
- DefaultReassembleTimeout = 30 * time.Second
-
// HighFragThreshold is the threshold at which we start trimming old
// fragmented packets. Linux uses a default value of 4 MB. See
// net.ipv4.ipfrag_high_thresh for more information.
@@ -243,3 +241,78 @@ func (f *Fragmentation) releaseReassemblersLocked() {
f.release(r)
}
}
+
+// PacketFragmenter is the book-keeping struct for packet fragmentation.
+type PacketFragmenter struct {
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ innerMTU int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
+}
+
+// MakePacketFragmenter prepares the struct needed for packet fragmentation.
+//
+// pkt is the packet to be fragmented.
+//
+// innerMTU is the maximum number of bytes of fragmentable data a fragment can
+// have.
+//
+// reserve is the number of bytes that should be reserved for the headers in
+// each generated fragment.
+func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter {
+ // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
+ // repeated in each fragment. However we do not currently support any header
+ // of that kind yet, so the following computation is valid for both IPv4 and
+ // IPv6.
+ // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
+ // supported for outbound packets, the fragmentable data should not include
+ // these headers.
+ var fragmentableData buffer.VectorisedView
+ fragmentableData.AppendView(pkt.TransportHeader().View())
+ fragmentableData.Append(pkt.Data)
+ fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU
+
+ return PacketFragmenter{
+ data: fragmentableData,
+ reserve: reserve,
+ innerMTU: innerMTU,
+ fragmentCount: fragmentCount,
+ }
+}
+
+// BuildNextFragment returns a packet with the payload of the next fragment,
+// along with the fragment's offset, the number of bytes copied and a boolean
+// indicating if there are more fragments left or not. If this function is
+// called again after it indicated that no more fragments were left, it will
+// panic.
+//
+// Note that the returned packet will not have its network and link headers
+// populated, but space for them will be reserved. The transport header will be
+// stored in the packet's data.
+func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
+ if pf.currentFragment >= pf.fragmentCount {
+ panic("BuildNextFragment should not be called again after the last fragment was returned")
+ }
+
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: pf.reserve,
+ })
+
+ // Copy data for the fragment.
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU)
+
+ offset := pf.fragmentOffset
+ pf.fragmentOffset += copied
+ pf.currentFragment++
+ more := pf.currentFragment != pf.fragmentCount
+
+ return fragPkt, offset, copied, more
+}
+
+// RemainingFragmentCount returns the number of fragments left to be built.
+func (pf *PacketFragmenter) RemainingFragmentCount() int {
+ return pf.fragmentCount - pf.currentFragment
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 189b223c5..d3c7d7f92 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -20,10 +20,16 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
)
+// reassembleTimeout is dummy timeout used for testing, where the clock never
+// advances.
+const reassembleTimeout = 1
+
// vv is a helper to build VectorisedView from different strings.
func vv(size int, pieces ...string) buffer.VectorisedView {
views := make([]buffer.View, len(pieces))
@@ -96,7 +102,7 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
@@ -251,7 +257,7 @@ func TestReassemblingTimeout(t *testing.T) {
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
// Send first fragment with id = 1.
@@ -275,7 +281,7 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
// Send the same packet again.
@@ -370,7 +376,7 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{})
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
_, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
@@ -381,3 +387,113 @@ func TestErrors(t *testing.T) {
})
}
}
+
+type fragmentInfo struct {
+ remaining int
+ copied int
+ offset int
+ more bool
+}
+
+func TestPacketFragmenter(t *testing.T) {
+ const (
+ reserve = 60
+ proto = 0
+ )
+
+ tests := []struct {
+ name string
+ innerMTU int
+ transportHeaderLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ }{
+ {
+ name: "Packet exactly fits in MTU",
+ innerMTU: 1280,
+ transportHeaderLen: 0,
+ payloadSize: 1280,
+ wantFragments: []fragmentInfo{
+ {remaining: 0, copied: 1280, offset: 0, more: false},
+ },
+ },
+ {
+ name: "Packet exactly does not fit in MTU",
+ innerMTU: 1000,
+ transportHeaderLen: 0,
+ payloadSize: 1001,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 1000, offset: 0, more: true},
+ {remaining: 0, copied: 1, offset: 1000, more: false},
+ },
+ },
+ {
+ name: "Packet has a transport header",
+ innerMTU: 560,
+ transportHeaderLen: 40,
+ payloadSize: 560,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 560, offset: 0, more: true},
+ {remaining: 0, copied: 40, offset: 560, more: false},
+ },
+ },
+ {
+ name: "Packet has a huge transport header",
+ innerMTU: 500,
+ transportHeaderLen: 1300,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {remaining: 3, copied: 500, offset: 0, more: true},
+ {remaining: 2, copied: 500, offset: 500, more: true},
+ {remaining: 1, copied: 500, offset: 1000, more: true},
+ {remaining: 0, copied: 300, offset: 1500, more: false},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
+ var originalPayload buffer.VectorisedView
+ originalPayload.AppendView(pkt.TransportHeader().View())
+ originalPayload.Append(pkt.Data)
+ var reassembledPayload buffer.VectorisedView
+ pf := MakePacketFragmenter(pkt, test.innerMTU, reserve)
+ for i := 0; ; i++ {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ wantFragment := test.wantFragments[i]
+ if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
+ t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
+ }
+ if copied != wantFragment.copied {
+ t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
+ }
+ if offset != wantFragment.offset {
+ t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
+ }
+ if more != wantFragment.more {
+ t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
+ }
+ if got := fragPkt.Size(); got > test.innerMTU {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU)
+ }
+ if got := fragPkt.AvailableHeaderBytes(); got != reserve {
+ t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
+ }
+ if got := fragPkt.TransportHeader().View().Size(); got != 0 {
+ t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
+ }
+ reassembledPayload.Append(fragPkt.Data)
+ if !more {
+ if i != len(test.wantFragments)-1 {
+ t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
+ }
+ break
+ }
+ }
+ if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" {
+ t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 6861cfdaf..d436873b6 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -270,7 +270,7 @@ func buildDummyStack(t *testing.T) *stack.Stack {
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
- tester testObject
+ testObject
mu struct {
sync.RWMutex
@@ -302,10 +302,6 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
-func (t *testInterface) LinkEndpoint() stack.LinkEndpoint {
- return &t.tester
-}
-
func TestSourceAddressValidation(t *testing.T) {
rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
@@ -517,7 +513,7 @@ func TestIPv4Send(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
@@ -538,10 +534,10 @@ func TestIPv4Send(t *testing.T) {
})
// Issue the write.
- nic.tester.protocol = 123
- nic.tester.srcAddr = localIPv4Addr
- nic.tester.dstAddr = remoteIPv4Addr
- nic.tester.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv4Addr
+ nic.testObject.dstAddr = remoteIPv4Addr
+ nic.testObject.contents = payload
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -560,12 +556,12 @@ func TestIPv4Receive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -590,10 +586,10 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -606,8 +602,8 @@ func TestIPv4Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -640,11 +636,11 @@ func TestIPv4ReceiveControl(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -691,16 +687,16 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = view[dataOffset:]
- nic.tester.typ = c.expectedTyp
- nic.tester.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; nic.tester.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
@@ -710,12 +706,12 @@ func TestIPv4FragmentationReceive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
v4: true,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -758,10 +754,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv4Addr
- nic.tester.dstAddr = localIPv4Addr
- nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
@@ -776,8 +772,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
// Send second segment.
@@ -788,8 +784,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -797,7 +793,7 @@ func TestIPv6Send(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
@@ -821,10 +817,10 @@ func TestIPv6Send(t *testing.T) {
})
// Issue the write.
- nic.tester.protocol = 123
- nic.tester.srcAddr = localIPv6Addr
- nic.tester.dstAddr = remoteIPv6Addr
- nic.tester.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv6Addr
+ nic.testObject.dstAddr = remoteIPv6Addr
+ nic.testObject.contents = payload
r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
@@ -843,11 +839,11 @@ func TestIPv6Receive(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -871,10 +867,10 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv6Addr
- nic.tester.dstAddr = localIPv6Addr
- nic.tester.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
@@ -888,8 +884,8 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if nic.tester.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -931,11 +927,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
nic := testInterface{
- tester: testObject{
+ testObject: testObject{
t: t,
},
}
- ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -994,19 +990,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- nic.tester.protocol = 10
- nic.tester.srcAddr = remoteIPv6Addr
- nic.tester.dstAddr = localIPv6Addr
- nic.tester.contents = view[dataOffset:]
- nic.tester.typ = c.expectedTyp
- nic.tester.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; nic.tester.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 0a7e98ed1..7fc12e229 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -28,12 +28,15 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 5c4f715d7..3407755ed 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,8 @@
package ipv4
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -77,31 +79,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Echo.Increment()
// Only send a reply if the checksum is valid.
- wantChecksum := h.Checksum()
- // Reset the checksum field to 0 to can calculate the proper
- // checksum. We'll have to reset this before we hand the packet
- // off.
+ headerChecksum := h.Checksum()
h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- if gotChecksum != wantChecksum {
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
+ calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(headerChecksum)
+ if calculatedChecksum != headerChecksum {
+ // It's possible that a raw socket still expects to receive this.
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
received.Invalid.Increment()
return
}
- // Make a copy of data before pkt gets sent to raw socket.
- // DeliverTransportPacket will take ownership of pkt.
- replyData := pkt.Data.Clone(nil)
- replyData.TrimFront(header.ICMPv4MinimumSize)
+ // DeliverTransportPacket will take ownership of pkt so don't use it beyond
+ // this point. Make a deep copy of the data before pkt gets sent as we will
+ // be modifying fields.
+ //
+ // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
+ // waiting endpoints. Consider moving responsibility for doing the copy to
+ // DeliverTransportPacket so that is is only done when needed.
+ replyData := pkt.Data.ToOwnedView()
+ replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...))
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
@@ -117,32 +117,49 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
}
defer r.Release()
- // Use the remote link address from the incoming packet.
- r.ResolveWith(remoteLinkAddr)
-
- // Prepare a reply packet.
- icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize)
- copy(icmpHdr, h)
- icmpHdr.SetType(header.ICMPv4EchoReply)
- icmpHdr.SetChecksum(0)
- icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0)))
- dataVV := buffer.View(icmpHdr).ToVectorisedView()
- dataVV.Append(replyData)
+ // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
+ // header information, we may have to change this code to handle the
+ // ICMP header no longer being in the data buffer.
+
+ // Because IP and ICMP are so closely intertwined, we need to handcraft our
+ // IP header to be able to follow RFC 792. The wording on page 13 is as
+ // follows:
+ // IP Fields:
+ // Addresses
+ // The address of the source in an echo message will be the
+ // destination of the echo reply message. To form an echo reply
+ // message, the source and destination addresses are simply reversed,
+ // the type code changed to 0, and the checksum recomputed.
+ //
+ // This was interpreted by early implementors to mean that all options must
+ // be copied from the echo request IP header to the echo reply IP header
+ // and this behaviour is still relied upon by some applications.
+ //
+ // Create a copy of the IP header we received, options and all, and change
+ // The fields we need to alter.
+ //
+ // We need to produce the entire packet in the data segment in order to
+ // use WriteHeaderIncludedPacket().
+ replyIPHdr.SetSourceAddress(r.LocalAddress)
+ replyIPHdr.SetDestinationAddress(r.RemoteAddress)
+ replyIPHdr.SetTTL(r.DefaultTTL())
+
+ replyICMPHdr := header.ICMPv4(replyData)
+ replyICMPHdr.SetType(header.ICMPv4EchoReply)
+ replyICMPHdr.SetChecksum(0)
+ replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0))
+
+ replyVV := buffer.View(replyIPHdr).ToVectorisedView()
+ replyVV.AppendView(replyData)
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: dataVV,
+ Data: replyVV,
})
- // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header
- // information we will have to change this code to handle the ICMP header
- // no longer being in the data buffer.
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- // Send out the reply packet.
+
+ // The checksum will be calculated so we don't need to do it here.
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
- Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, replyPkt); err != nil {
+ if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -211,18 +228,18 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonProtoUnreachable is an error where the transport protocol is
+// not supported.
+type icmpReasonProtoUnreachable struct{}
+
+func (*icmpReasonProtoUnreachable) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- sent := r.Stats().ICMP.V4PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
@@ -251,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
return nil
}
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ sent := p.stack.Stats().ICMP.V4PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
networkHeader := pkt.NetworkHeader().View()
transportHeader := pkt.TransportHeader().View()
@@ -287,8 +323,6 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// Assume any type we don't know about may be an error type.
return nil
}
- } else if transportHeader.IsEmpty() {
- return nil
}
// Now work out how much of the triggering packet we should return.
@@ -303,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// least 8 bytes of the payload must be included. Today linux and other
// systems implement the RFC 1812 definition and not the original
// requirement. We treat 8 bytes as the minimum but will try send more.
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv4MinimumProcessableDatagramSize {
mtu = header.IPv4MinimumProcessableDatagramSize
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
@@ -336,19 +370,27 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
ReserveHeaderBytes: headerLen,
Data: payload,
})
+
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ switch reason.(type) {
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ case *icmpReasonProtoUnreachable:
+ icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter := sent.DstUnreachable
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ counter := sent.DstUnreachable
- if err := r.WritePacket(
+ if err := route.WritePacket(
nil, /* gso */
stack.NetworkHeaderParams{
Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
+ TTL: route.DefaultTTL(),
TOS: stack.DefaultTOS,
},
icmpPkt,
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index ad7a767a4..c5ac7b8b5 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -18,6 +18,7 @@ package ipv4
import (
"fmt"
"sync/atomic"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -30,6 +31,15 @@ import (
)
const (
+ // As per RFC 791 section 3.2:
+ // The current recommendation for the initial timer setting is 15 seconds.
+ // This may be changed as experience with this protocol accumulates.
+ //
+ // Considering that it is an old recommendation, we use the same reassembly
+ // timeout that linux defines, which is 30 seconds:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138
+ reassembleTimeout = 30 * time.Second
+
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -56,7 +66,6 @@ var _ stack.NetworkEndpoint = (*endpoint)(nil)
type endpoint struct {
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
protocol *protocol
@@ -77,7 +86,6 @@ type endpoint struct {
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: nic.LinkEndpoint(),
dispatcher: dispatcher,
protocol: p,
}
@@ -168,21 +176,13 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -190,98 +190,26 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is already present in pkt.NetworkHeader.
-// pkt.TransportHeader may be set. mtu includes the IP header and options. This
-// does not support the DontFragment IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
- // This packet is too big, it needs to be fragmented.
- ip := header.IPv4(pkt.NetworkHeader().View())
- flags := ip.Flags()
-
- // Update mtu to take into account the header, which will exist in all
- // fragments anyway.
- innerMTU := mtu - int(ip.HeaderLength())
-
- // Round the MTU down to align to 8 bytes. Then calculate the number of
- // fragments. Calculate fragment sizes as in RFC791.
- innerMTU &^= 7
- n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
-
- outerMTU := innerMTU + int(ip.HeaderLength())
- offset := ip.FragmentOffset()
-
- // Keep the length reserved for link-layer, we need to create fragments with
- // the same reserved length.
- reservedForLink := pkt.AvailableHeaderBytes()
-
- // Destroy the packet, pull all payloads out for fragmentation.
- transHeader, data := pkt.TransportHeader().View(), pkt.Data
-
- // Where possible, the first fragment that is sent has the same
- // number of bytes reserved for header as the input packet. The link-layer
- // endpoint may depend on this for looking at, eg, L4 headers.
- transFitsFirst := len(transHeader) <= innerMTU
-
- for i := 0; i < n; i++ {
- reserve := reservedForLink + int(ip.HeaderLength())
- if i == 0 && transFitsFirst {
- // Reserve for transport header if it's going to be put in the first
- // fragment.
- reserve += len(transHeader)
- }
- fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: reserve,
- })
- fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
-
- // Copy data for the fragment.
- avail := innerMTU
-
- if n := len(transHeader); n > 0 {
- if n > avail {
- n = avail
- }
- if i == 0 && transFitsFirst {
- copy(fragPkt.TransportHeader().Push(n), transHeader)
- } else {
- fragPkt.Data.AppendView(transHeader[:n:n])
- }
- transHeader = transHeader[n:]
- avail -= n
- }
-
- if avail > 0 {
- n := data.Size()
- if n > avail {
- n = avail
- }
- data.ReadToVV(&fragPkt.Data, n)
- avail -= n
- }
-
- copied := uint16(innerMTU - avail)
-
- // Set lengths in header and calculate checksum.
- h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip)))
- copy(h, ip)
- if i != n-1 {
- h.SetTotalLength(uint16(outerMTU))
- h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
- } else {
- h.SetTotalLength(uint16(h.HeaderLength()) + copied)
- h.SetFlagsFragmentOffset(flags, offset)
- }
- h.SetChecksum(0)
- h.SetChecksum(^h.CalculateChecksum())
- offset += copied
+// writePacketFragments fragments pkt and writes the results on the link
+// endpoint. The IP header must already present in the original packet. The mtu
+// is the maximum size of the packets.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error {
+ networkHeader := header.IPv4(pkt.NetworkHeader().View())
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
- // Send out the fragment.
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader)
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1))
return err
}
r.Stats().IP.PacketsSent.Increment()
+ if !more {
+ break
+ }
}
+
return nil
}
@@ -303,7 +231,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -329,7 +257,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -345,10 +273,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
+ if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, e.nic.MTU(), pkt)
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -377,8 +306,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
@@ -392,7 +324,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -401,8 +333,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -461,9 +394,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return nil
}
+ if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
-
- return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
@@ -566,7 +502,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1
+ // A host SHOULD generate Destination Unreachable messages with code:
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported
+ _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -794,14 +736,36 @@ func calculateMTU(mtu uint32) uint32 {
return mtu - header.IPv4MinimumSize
}
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ return mtu
+}
+
+// addressToUint32 translates an IPv4 address into its little endian uint32
+// representation.
+//
+// This function does the same thing as binary.LittleEndian.Uint32 but operates
+// on a tcpip.Address (a string) without the need to convert it to a byte slice,
+// which would cause an allocation.
+func addressToUint32(addr tcpip.Address) uint32 {
+ _ = addr[3] // bounds check hint to compiler
+ return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
+}
+
// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number, and a random initial
-// value (generated once on initialization) to generate the hash.
+// destination address, the transport protocol number and a 32-bit number to
+// generate the hash.
func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- t := r.LocalAddress
- a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- t = r.RemoteAddress
- b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ a := addressToUint32(r.LocalAddress)
+ b := addressToUint32(r.RemoteAddress)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
@@ -821,6 +785,29 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ }
+}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeaderLength := len(originalIPHeader)
+ nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+
+ if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
}
+
+ flags := originalIPHeader.Flags()
+ if more {
+ flags |= header.IPv4FlagMoreFragments
+ }
+ nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset))
+ nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied))
+ nextFragIPHeader.SetChecksum(0)
+ nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum())
+
+ return fragPkt, more
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 277560e35..9916d783f 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -16,19 +16,24 @@ package ipv4_test
import (
"bytes"
+ "context"
"encoding/hex"
"math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -92,6 +97,276 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
+// checks the response.
+func TestIPv4Sanity(t *testing.T) {
+ const (
+ defaultMTU = header.IPv6MinimumMTU
+ ttl = 255
+ nicID = 1
+ randomSequence = 123
+ randomIdent = 42
+ )
+ var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ )
+
+ tests := []struct {
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ shouldFail bool
+ expectICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ options []byte
+ }{
+ {
+ name: "valid",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ },
+ // The TTL tests check that we are not rejecting an incoming packet
+ // with a zero or one TTL, which has been a point of confusion in the
+ // past as RFC 791 says: "If this field contains the value zero, then the
+ // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies
+ // for the case of the destination host, stating as follows.
+ //
+ // A host MUST NOT send a datagram with a Time-to-Live (TTL)
+ // value of zero.
+ //
+ // A host MUST NOT discard a datagram just because it was
+ // received with TTL less than 2.
+ {
+ name: "zero TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 0,
+ shouldFail: false,
+ },
+ {
+ name: "one TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ shouldFail: false,
+ },
+ {
+ name: "End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "NOP options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 1, 1},
+ },
+ {
+ name: "NOP and End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 0, 0},
+ },
+ {
+ name: "bad header length",
+ headerLength: header.IPv4MinimumSize - 1,
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (0)",
+ maxTotalLength: 0,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip + icmp - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad protocol",
+ maxTotalLength: defaultMTU,
+ transportProtocol: 99,
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: true,
+ ICMPType: header.ICMPv4DstUnreachable,
+ ICMPCode: header.ICMPv4ProtoUnreachable,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ }
+
+ // Default routes for IPv4 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Round up the header size to the next multiple of 4 as RFC 791, page 11
+ // says: "Internet Header Length is the length of the internet header
+ // in 32 bit words..." and on page 23: "The internet header padding is
+ // used to ensure that the internet header ends on a 32 bit boundary."
+ ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1)
+
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+
+ // Specify ident/seq to make sure we get the same in the response.
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ if test.maxTotalLength < totalLen {
+ totalLen = test.maxTotalLength
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHeaderLength),
+ TotalLength: totalLen,
+ Protocol: test.transportProtocol,
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: ipv4Addr.Address,
+ })
+ if n := copy(ip.Options(), test.options); n != len(test.options) {
+ t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options))
+ }
+ // Override the correct value if the test case specified one.
+ if test.headerLength != 0 {
+ ip.SetHeaderLength(test.headerLength)
+ }
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ reply, ok := e.Read()
+ if !ok {
+ if test.shouldFail {
+ if test.expectICMP {
+ t.Fatal("expected ICMP error response missing")
+ }
+ return // Expected silent failure.
+ }
+ t.Fatal("expected ICMP echo reply missing")
+ }
+
+ // Check the route that brought the packet to us.
+ if reply.Route.LocalAddress != ipv4Addr.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
+ }
+ if reply.Route.RemoteAddress != remoteIPv4Addr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
+ }
+
+ // Make sure it's all in one buffer.
+ vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
+ replyIPHeader := header.IPv4(vv.ToView())
+
+ // At this stage we only know it's an IP header so verify that much.
+ checker.IPv4(t, replyIPHeader,
+ checker.SrcAddr(ipv4Addr.Address),
+ checker.DstAddr(remoteIPv4Addr),
+ )
+
+ // All expected responses are ICMP packets.
+ if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
+ t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ }
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
+
+ // Sanity check the response.
+ switch replyICMPHeader.Type() {
+ case header.ICMPv4DstUnreachable:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+ if !test.shouldFail || !test.expectICMP {
+ t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ return
+ case header.ICMPv4EchoReply:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPv4HeaderLength(ipHeaderLength),
+ checker.IPv4Options(test.options),
+ checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.ICMPv4(
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Seq(randomSequence),
+ checker.ICMPv4Ident(randomIdent),
+ checker.ICMPv4Checksum(),
+ ),
+ )
+ if test.shouldFail {
+ t.Fatalf("unexpected Echo Reply packet\n")
+ }
+ default:
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ }
+ })
+ }
+}
+
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -123,16 +398,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
if got, want := len(ip), int(mtu); got > want {
t.Errorf("fragment is too large, got %d want %d", got, want)
}
- if i == 0 {
- got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size()
- // sourcePacketInfo does not have NetworkHeader added, simulate one.
- want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size()
- // Check that it kept the transport header in packet.TransportHeader if
- // it fits in the first fragment.
- if want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
- }
- }
if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
@@ -162,6 +427,8 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
func TestFragmentation(t *testing.T) {
+ const ttl = 42
+
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
manyPayloadViewsSizes[i] = 7
@@ -175,15 +442,15 @@ func TestFragmentation(t *testing.T) {
payloadViewsSizes []int
expectedFrags int
}{
- {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
{"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
}
for _, ft := range fragTests {
@@ -194,11 +461,11 @@ func TestFragmentation(t *testing.T) {
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("got err = %s, want = nil", err)
+ t.Fatalf("r.WritePacket(_, _, _) = %s", err)
}
if got := len(ep.WrittenPackets); got != ft.expectedFrags {
@@ -207,6 +474,9 @@ func TestFragmentation(t *testing.T) {
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
@@ -215,36 +485,70 @@ func TestFragmentation(t *testing.T) {
// TestFragmentationErrors checks that errors are returned from write packet
// correctly.
func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ expectedError := tcpip.ErrAborted
fragTests := []struct {
description string
mtu uint32
transportHeaderLength int
- payloadViewsSizes []int
- err *tcpip.Error
+ payloadSize int
allowPackets int
+ fragmentCount int
}{
- {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1},
- {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0},
+ {
+ description: "No frag",
+ mtu: 2000,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 0,
+ fragmentCount: 1,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 0,
+ fragmentCount: 3,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 1,
+ fragmentCount: 3,
+ },
+ {
+ description: "Error on first frag MTU smaller than header",
+ mtu: 500,
+ transportHeaderLength: 1000,
+ payloadSize: 500,
+ allowPackets: 0,
+ fragmentCount: 4,
+ },
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.err {
- t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
+ if err != expectedError {
+ t.Errorf("got WritePacket() = %s, want = %s", err, expectedError)
}
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
+ if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
+ }
})
}
}
@@ -1005,6 +1309,7 @@ func TestReceiveFragments(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
+
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
@@ -1150,12 +1455,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\x10\x00\x00\x02"
)
if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ mask := tcpip.AddressMask(header.IPv4Broadcast)
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1164,7 +1470,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
}
return rt
}
@@ -1188,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
lm.limit--
return false, false
}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize,
+ TTL: ipv4.DefaultTTL,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a ARP request since link address resolution should be
+ // performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != arp.ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(p.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address)
+ }
+ }
+
+ // Send an ARP reply to complete link address resolution.
+ {
+ hdr := buffer.View(make([]byte, header.ARPSize))
+ packet := header.ARP(hdr)
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), host2NICLinkAddr)
+ copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address)
+ copy(packet.HardwareAddressTarget(), host1NICLinkAddr)
+ copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address)
+ e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 97adbcbd4..a30437f02 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -18,6 +18,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -41,6 +42,7 @@ go_test(
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 4b4b483cc..7be35c78b 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -286,6 +286,17 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
}
+ // As per RFC 4861 section 7.1.1:
+ // A node MUST silently discard any received Neighbor Solicitation
+ // messages that do not satisfy all of the following validity checks:
+ // ...
+ // - If the IP source address is the unspecified address, the IP
+ // destination address is a solicited-node multicast address.
+ if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) {
+ received.Invalid.Increment()
+ return
+ }
+
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
// route we end up with here has as its LocalAddress such a
@@ -429,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
localAddr := r.LocalAddress
@@ -445,9 +454,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
defer r.Release()
- // Use the link address from the source of the original packet.
- r.ResolveWith(remoteLinkAddr)
-
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data,
@@ -635,18 +641,21 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- snaddr := header.SolicitedNodeAddr(addr)
-
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
// route here. Note, we would need the nicID to do this properly so the right
// NIC (associated to linkEP) is used to send the NDP NS message.
- r := &stack.Route{
+ r := stack.Route{
LocalAddress: localAddr,
- RemoteAddress: snaddr,
+ RemoteAddress: addr,
RemoteLinkAddress: remoteLinkAddr,
}
+
+ // If a remote address is not already known, then send a multicast
+ // solicitation since multicast addresses have a static mapping to link
+ // addresses.
if len(r.RemoteLinkAddress) == 0 {
- r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
+ r.RemoteAddress = header.SolicitedNodeAddr(addr)
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress)
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -672,7 +681,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -690,6 +699,36 @@ type icmpReason interface {
isICMPReason()
}
+// icmpReasonParameterProblem is an error during processing of extension headers
+// or the fixed header defined in RFC 4443 section 3.4.
+type icmpReasonParameterProblem struct {
+ code header.ICMPv6Code
+
+ // respondToMulticast indicates that we are sending a packet that falls under
+ // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondToMulticast bool
+
+ // pointer is defined in the RFC 4443 setion 3.4 which reads:
+ //
+ // Pointer Identifies the octet offset within the invoking packet
+ // where the error was detected.
+ //
+ // The pointer will point beyond the end of the ICMPv6
+ // packet if the field in error is beyond what can fit
+ // in the maximum size of an ICMPv6 error message.
+ pointer uint32
+}
+
+func (*icmpReasonParameterProblem) isICMPReason() {}
+
// icmpReasonPortUnreachable is an error where the transport protocol has no
// listener and no alternative means to inform the sender.
type icmpReasonPortUnreachable struct{}
@@ -698,18 +737,11 @@ func (*icmpReasonPortUnreachable) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- stats := r.Stats().ICMP
- sent := stats.V6PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// Only send ICMP error if the address is not a multicast v6
// address and the source is not the unspecified address.
//
- // TODO(b/164522993) There are exceptions to this rule.
+ // There are exceptions to this rule.
// See: point e.3) RFC 4443 section-2.4
//
// (e) An ICMPv6 error message MUST NOT be originated as a result of
@@ -727,7 +759,32 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// Section 4.2 of [IPv6]) that has the Option Type highest-
// order two bits set to 10).
//
- if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any {
+ var allowResponseToMulticast bool
+ if reason, ok := reason.(*icmpReasonParameterProblem); ok {
+ allowResponseToMulticast = reason.respondToMulticast
+ }
+
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ return nil
+ }
+
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ stats := p.stack.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
return nil
}
@@ -757,11 +814,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// packet that caused the error) as possible without making
// the error message packet exceed the minimum IPv6 MTU
// [IPv6].
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv6MinimumMTU {
mtu = header.IPv6MinimumMTU
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
available := int(mtu) - headerLen
if available < header.IPv6MinimumSize {
return nil
@@ -780,12 +837,30 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data))
- counter := sent.DstUnreachable
- err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt)
- if err != nil {
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ icmpHdr.SetType(header.ICMPv6ParamProblem)
+ icmpHdr.SetCode(reason.code)
+ icmpHdr.SetTypeSpecific(reason.pointer)
+ counter = sent.ParamProblem
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ counter = sent.DstUnreachable
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data))
+ if err := route.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ newPkt,
+ ); err != nil {
sent.Dropped.Increment()
return err
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 31370c1d4..3affcc4e4 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -16,17 +16,21 @@ package ipv6
import (
"context"
+ "net"
"reflect"
"strings"
"testing"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -39,6 +43,9 @@ const (
defaultChannelSize = 1
defaultMTU = 65536
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 30 * time.Second
)
var (
@@ -50,6 +57,10 @@ type stubLinkEndpoint struct {
stack.LinkEndpoint
}
+func (*stubLinkEndpoint) MTU() uint32 {
+ return defaultMTU
+}
+
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
// Indicate that resolution for link layer addresses is required to send
// packets over this link. This is needed so the NIC knows to allocate a
@@ -105,7 +116,9 @@ func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
var _ stack.NetworkInterface = (*testInterface)(nil)
-type testInterface struct{}
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
func (*testInterface) ID() tcpip.NICID {
return 0
@@ -123,10 +136,6 @@ func (*testInterface) Enabled() bool {
return true
}
-func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
- return nil
-}
-
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -1219,19 +1228,22 @@ func TestLinkAddressRequest(t *testing.T) {
mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
tests := []struct {
- name string
- remoteLinkAddr tcpip.LinkAddress
- expectLinkAddr tcpip.LinkAddress
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectedLinkAddr tcpip.LinkAddress
+ expectedAddr tcpip.Address
}{
{
- name: "Unicast",
- remoteLinkAddr: linkAddr1,
- expectLinkAddr: linkAddr1,
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectedLinkAddr: linkAddr1,
+ expectedAddr: lladdr0,
},
{
- name: "Multicast",
- remoteLinkAddr: "",
- expectLinkAddr: mcaddr,
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectedLinkAddr: mcaddr,
+ expectedAddr: snaddr,
},
}
@@ -1254,9 +1266,229 @@ func TestLinkAddressRequest(t *testing.T) {
if !ok {
t.Fatal("expected to send a link address request")
}
+ if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr)
+ }
+ if pkt.Route.RemoteAddress != test.expectedAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr)
+ }
+ if pkt.Route.LocalAddress != lladdr1 {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
+ }
+ checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
+ checker.SrcAddr(lladdr1),
+ checker.DstAddr(test.expectedAddr),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
+ ))
+ }
+}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
- if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
}
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a neighbor solicitation since link address resolution should
+ // be performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}),
+ ))
+ }
+
+ // Send a neighbor advertisement to complete link address resolution.
+ {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index aff4e1425..2bd8f4ece 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,9 +16,12 @@
package ipv6
import (
+ "encoding/binary"
"fmt"
+ "hash/fnv"
"sort"
"sync/atomic"
+ "time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -26,10 +29,20 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // As per RFC 8200 section 4.5:
+ // If insufficient fragments are received to complete reassembly of a packet
+ // within 60 seconds of the reception of the first-arriving fragment of that
+ // packet, reassembly of that packet must be abandoned.
+ //
+ // Linux also uses 60 seconds for reassembly timeout:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456
+ reassembleTimeout = 60 * time.Second
+
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -40,6 +53,9 @@ const (
// DefaultTTL is the default hop limit for IPv6 Packets egressed by
// Netstack.
DefaultTTL = 64
+
+ // buckets for fragment identifiers
+ buckets = 2048
)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
@@ -50,7 +66,6 @@ var _ NDPEndpoint = (*endpoint)(nil)
type endpoint struct {
nic stack.NetworkInterface
- linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
dispatcher stack.TransportDispatcher
@@ -348,21 +363,13 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
@@ -376,7 +383,44 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
+}
+
+func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
+ return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone)
+}
+
+// handleFragments fragments pkt and calls the handler function on each
+// fragment. It returns the number of fragments handled and the number of
+// fragments left to be processed. The IP header must already be present in the
+// original packet. The mtu is the maximum size of the packets. The transport
+// header protocol number is required to avoid parsing the IPv6 extension
+// headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ if fragMTU < pkt.TransportHeader().View().Size() {
+ // As per RFC 8200 Section 4.5, the Transport Header is expected to be small
+ // enough to fit in the first fragment.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt))
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ var n int
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id)
+ if err := handler(fragPkt); err != nil {
+ return n, pf.RemainingFragmentCount() + 1, err
+ }
+ n++
+ if !more {
+ break
+ }
+ }
+
+ return n, 0, nil
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -402,7 +446,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
return nil
@@ -423,14 +467,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if e.packetMustBeFragmented(pkt, gso) {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
+
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
r.Stats().IP.PacketsSent.Increment()
return nil
}
-// WritePackets implements stack.LinkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
@@ -441,6 +500,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
+ if e.packetMustBeFragmented(pb, gso) {
+ current := pb
+ _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(current, fragPkt)
+ current = current.Next()
+ return nil
+ })
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ // The fragmented packet can be released. The rest of the packets can be
+ // processed.
+ pkts.Remove(pb)
+ pb = current
+ }
}
// iptables filtering. All packets that reach here are locally
@@ -451,8 +527,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
@@ -466,7 +545,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -475,8 +554,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -536,7 +616,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
- for firstHeader := true; ; firstHeader = false {
+ for {
+ // Keep track of the start of the previous header so we can report the
+ // special case of a Hop by Hop at a location other than at the start.
+ previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -550,11 +633,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
- // (unrecognized next header) error in response to an extension header's
- // Next Header field with the Hop By Hop extension header identifier.
- if !firstHeader {
+ if previousHeaderStart != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ }, pkt)
return
}
@@ -576,13 +659,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
@@ -593,16 +688,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.4, if a node encounters a routing header with
// an unrecognized routing type value, with a non-zero Segments Left
// value, the node must discard the packet and send an ICMP Parameter
- // Problem, Code 0. If the Segments Left is 0, the node must ignore the
- // Routing extension header and process the next header in the packet.
+ // Problem, Code 0 to the packet's Source Address, pointing to the
+ // unrecognized Routing Type.
+ //
+ // If the Segments Left is 0, the node must ignore the Routing extension
+ // header and process the next header in the packet.
//
// Note, the stack does not yet handle any type of routing extension
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
- // unrecognized routing types with a non-zero Segments Left value.
if extHdr.SegmentsLeft() != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
return
}
@@ -737,13 +836,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
@@ -767,8 +878,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt, hasFragmentHeader)
} else {
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ r.Stats().IP.PacketsDelivered.Increment()
switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
@@ -777,18 +887,41 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// message with Code 4 in response to a packet for which the
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC 8200 section 4. (page 7):
+ // Extension headers are numbered from IANA IP Protocol Numbers
+ // [IANA-PN], the same values used for IPv4 and IPv6. When
+ // processing a sequence of Next Header values in a packet, the
+ // first one that is not an extension header [IANA-EH] indicates
+ // that the next item in the packet is the corresponding upper-layer
+ // header.
+ // With more related information on page 8:
+ // If, as a result of processing a header, the destination node is
+ // required to proceed to the next header but the Next Header value
+ // in the current header is unrecognized by the node, it should
+ // discard the packet and send an ICMP Parameter Problem message to
+ // the source of the packet, with an ICMP Code value of 1
+ // ("unrecognized Next Header type encountered") and the ICMP
+ // Pointer field containing the offset of the unrecognized value
+ // within the original packet.
+ //
+ // Which when taken together indicate that an unknown protocol should
+ // be treated as an unrecognized next header value.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
}
default:
- // If we receive a packet for an extension header we do not yet handle,
- // drop the packet for now.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
r.Stats().UnknownProtocolRcvdPackets.Increment()
return
}
@@ -1097,6 +1230,9 @@ type protocol struct {
eps map[*endpoint]struct{}
}
+ ids []uint32
+ hashIV uint32
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -1157,7 +1293,6 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
@@ -1318,10 +1453,15 @@ type Options struct {
func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
opts.NDPConfigs.validate()
+ ids := hash.RandN32(buckets)
+ hashIV := hash.RandN32(1)[0]
+
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
stack: s,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ ids: ids,
+ hashIV: hashIV,
ndpDisp: opts.NDPDisp,
ndpConfigs: opts.NDPConfigs,
@@ -1339,3 +1479,73 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return NewProtocolWithOptions(Options{})(s)
}
+
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // MTU because they should only be transmitted once.
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ mtu -= header.IPv6FragmentHeaderSize
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
+ return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address and 32-bit number to generate the hash.
+func hashRoute(r *stack.Route, hashIV uint32) uint32 {
+ // The FNV-1a was chosen because it is a fast hashing algorithm, and
+ // cryptographic properties are not needed here.
+ h := fnv.New32a()
+ if _, err := h.Write([]byte(r.LocalAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ if _, err := h.Write([]byte(r.RemoteAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ s := make([]byte, 4)
+ binary.LittleEndian.PutUint32(s, hashIV)
+ if _, err := h.Write(s); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err))
+ }
+
+ return h.Sum32()
+}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeadersLength := len(originalIPHeaders)
+ fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+ fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+
+ // Copy the IPv6 header and any extension headers already populated.
+ if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
+ }
+ fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
+
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
+ fragmentHeader.Encode(&header.IPv6FragmentFields{
+ M: more,
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ Identification: id,
+ NextHeader: uint8(transportProto),
+ })
+
+ return fragPkt, more
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 94344057e..e792ca9e2 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,17 +15,21 @@
package ipv6
import (
+ "encoding/hex"
+ "fmt"
"math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -136,6 +140,82 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
}
}
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+ // sourcePacket does not have its IP Header populated. Let's copy the one
+ // from the first fragment.
+ source := header.IPv6(packets[0].NetworkHeader().View())
+ sourceIPHeadersLen := len(source)
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
+ source = append(source, vv.ToView()...)
+
+ var reassembledPayload buffer.VectorisedView
+ for i, fragment := range packets {
+ // Confirm that the packet is valid.
+ allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views())
+ fragmentIPHeaders := header.IPv6(allBytes.ToView())
+ if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders))
+ }
+
+ fragmentIPHeadersLength := fragment.NetworkHeader().View().Size()
+ if fragmentIPHeadersLength != sourceIPHeadersLen {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen)
+ }
+
+ if got := len(fragmentIPHeaders); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu)
+ }
+
+ sourceIPHeader := source[:header.IPv6MinimumSize]
+ fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize]
+
+ if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize)
+ }
+
+ // We expect the IPv6 Header to be similar across each fragment, besides the
+ // payload length.
+ sourceIPHeader.SetPayloadLength(0)
+ fragmentIPHeader.SetPayloadLength(0)
+ if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
+ }
+
+ if len(packets) > 1 {
+ // If the source packet was big enough that it needed fragmentation, let's
+ // inspect the fragment header. Because no other extension headers are
+ // supported, it will always be the last extension header.
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength])
+
+ if got := fragmentHeader.More(); got != wantFragments[i].more {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more)
+ }
+ if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset)
+ }
+ if got := fragmentHeader.NextHeader(); got != uint8(proto) {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto))
+ }
+ }
+
+ // Store the reassembled payload as we parse each fragment. The payload
+ // includes the Transport header and everything after.
+ reassembledPayload.AppendView(fragment.TransportHeader().View())
+ reassembledPayload.Append(fragment.Data)
+ }
+
+ result := reassembledPayload.ToView()
+ if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+
+ return nil
+}
+
// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
// UDP packets destined to the IPv6 link-local all-nodes multicast address.
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
@@ -170,8 +250,6 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
protocolFactory stack.TransportProtocolFactory
@@ -195,7 +273,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: nicID,
},
@@ -295,17 +373,22 @@ func TestAddIpv6Address(t *testing.T) {
}
func TestReceiveIPv6ExtHdrs(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
extHdr func(nextHdr uint8) ([]byte, uint8)
shouldAccept bool
+ // Should we expect an ICMP response and if so, with what contents?
+ expectICMP bool
+ ICMPType header.ICMPv6Type
+ ICMPCode header.ICMPv6Code
+ pointer uint32
+ multicast bool
}{
{
name: "None",
extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
shouldAccept: true,
+ expectICMP: false,
},
{
name: "hopbyhop with unknown option skippable action",
@@ -336,9 +419,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action",
+ name: "hopbyhop with unknown option discard and send icmp action (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -348,12 +452,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -364,39 +474,97 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ multicast: true,
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: true,
},
{
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 1, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6FixedHeaderSize + 2,
},
{
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 0, 0, 0, 0,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
},
{
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
+ expectICMP: false,
},
{
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{},
+ noNextHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
name: "destination with unknown option skippable action",
@@ -412,6 +580,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: true,
+ expectICMP: false,
},
{
name: "destination with unknown option discard action",
@@ -427,9 +596,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "destination with unknown option discard and send icmp action",
+ name: "destination with unknown option discard and send icmp action (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -439,12 +609,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action unless multicast dest",
+ name: "destination with unknown option discard and send icmp action (muilticast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
+ }, destinationExtHdrID
+ },
+ multicast: true,
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -455,22 +651,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "routing - atomic fragment",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ nextHdr, 1,
- // Fragment extension header.
- nextHdr, 0, 0, 0, 1, 2, 3, 4,
- }, routingExtHdrID
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
+ }, destinationExtHdrID
},
- shouldAccept: true,
+ shouldAccept: false,
+ expectICMP: false,
+ multicast: true,
},
{
name: "atomic fragment - routing",
@@ -504,12 +711,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
return []byte{
// Routing extension header.
hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
// Hop By Hop extension header with skippable unknown option.
nextHdr, 0, 62, 4, 1, 2, 3, 4,
}, routingExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
+ name: "routing - hop by hop (with send icmp unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
+
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
},
{
name: "No next header",
@@ -553,6 +790,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
@@ -573,6 +811,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
}
@@ -582,7 +821,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(1, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -590,6 +829,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ // Add a default route so that a return packet knows where to go.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -631,12 +878,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr,
HopLimit: 255,
SrcAddr: addr1,
- DstAddr: addr2,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -650,6 +901,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = 0", got)
}
+ if !test.expectICMP {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpected packet received: %#v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected packet wasn't written out")
+ }
+
+ // Pack the output packet into a single buffer.View as the checkers
+ // assume that.
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
+ t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
+ }
+
+ ipHdr := header.IPv6(pkt)
+ checker.IPv6(t, ipHdr, checker.ICMPv6(
+ checker.ICMPv6Type(test.ICMPType),
+ checker.ICMPv6Code(test.ICMPCode)))
+
+ // We know we are looking at no extension headers in the error ICMP
+ // packets.
+ icmpPkt := header.ICMPv6(ipHdr.Payload())
+ // We know we sent small packets that won't be truncated when reflected
+ // back to us.
+ originalPacket := icmpPkt.Payload()
+ if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
+ t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want)
+ }
+ if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
+ t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
+ }
return
}
@@ -683,7 +972,6 @@ type fragmentData struct {
func TestReceiveIPv6Fragments(t *testing.T) {
const (
- nicID = 1
udpPayload1Length = 256
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
@@ -1815,7 +2103,6 @@ func TestWriteStats(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
rt := buildRoute(t, ep)
-
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1857,12 +2144,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1871,7 +2159,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err)
}
return rt
}
@@ -1922,3 +2210,293 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) {
}
}
}
+
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
+}
+
+type fragmentationTestCase struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transHdrLen int
+ extraHdrLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ expectedFrags int
+}
+
+var fragmentationTests = []fragmentationTestCase{
+ {
+ description: "No Fragmentation",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso nil",
+ mtu: 1280,
+ gso: nil,
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 176, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 76, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header and prependable bytes",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 20,
+ extraHdrLen: header.IPv6MinimumSize + 66,
+ payloadSize: 1500,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 296, more: false},
+ },
+ },
+}
+
+func TestFragmentation(t *testing.T) {
+ const (
+ ttl = 42
+ tos = stack.DefaultTOS
+ transportProto = tcp.ProtocolNumber
+ )
+
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt.Clone()
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != nil {
+ t.Fatalf("WritePacket(_, _, _): = %s", err)
+ }
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+ if len(ep.WrittenPackets) > 0 {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ }
+ })
+ }
+}
+
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ tests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
+ }
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if n != wantTotalPackets || err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from WritePacket
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ tests := []struct {
+ description string
+ mtu uint32
+ transHdrLen int
+ payloadSize int
+ allowPackets int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
+ }{
+ {
+ description: "No frag",
+ mtu: 2000,
+ payloadSize: 1000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 1300,
+ payloadSize: 3000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 1500,
+ payloadSize: 4000,
+ transHdrLen: 0,
+ allowPackets: 1,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on packet with MTU smaller than transport header",
+ mtu: 1280,
+ transHdrLen: 1500,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrMessageTooLong,
+ },
+ }
+
+ for _, ft := range tests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 48a4c65e3..40da011f8 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -1289,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
//
// TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
// LinkEndpoint.LinkAddress) before reaching this point.
- linkAddr := ndp.ep.linkEP.LinkAddress()
+ linkAddr := ndp.ep.nic.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
return false
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 25464a03a..9033a9ed5 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -410,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
naDst tcpip.Address
}{
{
- name: "Unspecified source to multicast destination",
+ name: "Unspecified source to solicited-node multicast destination",
nsOpts: nil,
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
@@ -437,11 +437,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: false,
- naSrc: nicAddr,
- naDst: header.IPv6AllNodesMulticastAddress,
+ nsInvalid: true,
},
{
name: "Unspecified source with source ll option to unicast destination",
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index c9e57dc0d..d0ffc299a 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -8,6 +8,7 @@ go_library(
"testutil.go",
],
visibility = [
+ "//pkg/tcpip/network/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
],
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 2eaeab779..eba97334e 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -56,7 +56,6 @@ go_library(
srcs = [
"addressable_endpoint_state.go",
"conntrack.go",
- "forwarder.go",
"headertype_string.go",
"icmp_rate_limit.go",
"iptables.go",
@@ -73,6 +72,7 @@ go_library(
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
+ "pending_packets.go",
"rand.go",
"registration.go",
"route.go",
@@ -123,7 +123,6 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
@@ -139,7 +138,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "forwarder_test.go",
+ "forwarding_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index db8ac1c2b..4d3acab96 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -679,11 +679,6 @@ type addressState struct {
}
}
-// NetworkEndpoint implements AddressEndpoint.
-func (a *addressState) NetworkEndpoint() NetworkEndpoint {
- return a.addressableEndpointState.networkEndpoint
-}
-
// AddressWithPrefix implements AddressEndpoint.
func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 91ae7dafc..0cd1da11f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -196,13 +196,14 @@ type bucket struct {
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
+//
+// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := pkt.Network()
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
+
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
@@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
dstAddr: netHeader.DestinationAddress(),
dstPort: tcpHeader.DestinationPort(),
transProto: netHeader.TransportProtocol(),
- netProto: header.IPv4ProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
@@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
// support cases when they are validated, e.g. when we can't offload
// receive checksumming.
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacketOutput manipulates ports for packets in Output hook.
@@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
@@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacket will manipulate the port and address of the packet if the
@@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
// TODO(gvisor.dev/issue/170): Support other transport protocols.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
}
// We only track TCP connections.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return
}
@@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
transProto: header.TCPProtocolNumber,
- netProto: header.IPv4ProtocolNumber,
+ netProto: netProto,
}
conn, _ := ct.connForTID(tid)
if conn == nil {
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarding_test.go
index 4e4b00a92..cf042309e 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -48,10 +48,9 @@ const (
type fwdTestNetworkEndpoint struct {
AddressableEndpointState
- nicID tcpip.NICID
+ nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
- ep LinkEndpoint
}
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
@@ -67,7 +66,7 @@ func (*fwdTestNetworkEndpoint) Enabled() bool {
func (*fwdTestNetworkEndpoint) Disable() {}
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
@@ -80,7 +79,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+ return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -99,7 +98,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
@@ -159,10 +158,9 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
e := &fwdTestNetworkEndpoint{
- nicID: nic.ID(),
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index faa503b00..8d6d9a7f1 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -502,11 +502,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, tcpip.ErrNotConnected
}
- return it.connections.originalDst(epID)
+ return it.connections.originalDst(epID, netProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 08063f6ff..538c4625d 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -34,7 +34,7 @@ func (at *AcceptTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -52,7 +52,7 @@ func (dt *DropTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -76,7 +76,7 @@ func (et *ErrorTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -99,7 +99,7 @@ func (uc *UserChainTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -118,7 +118,7 @@ func (rt *ReturnTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -153,7 +153,7 @@ func (rt *RedirectTarget) ID() TargetID {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -164,11 +164,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1) in Output and
- // to primary address of the incoming interface in Prerouting.
+ // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
- rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ } else {
+ rt.Addr = header.IPv6Loopback
+ }
case Prerouting:
rt.Addr = address
default:
@@ -177,8 +181,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
// we need to change dest address (for OUTPUT chain) or ports.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- switch protocol := netHeader.TransportProtocol(); protocol {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
udpHeader.SetDestinationPort(rt.Port)
@@ -186,10 +189,10 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(protocol, length)
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
@@ -198,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- // Change destination address.
- netHeader.SetDestinationAddress(rt.Addr)
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+
+ pkt.Network().SetDestinationAddress(rt.Addr)
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 27e1feec0..4df288798 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -131,10 +131,17 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
defer entry.mu.Unlock()
switch s := entry.neigh.State; s {
- case Reachable, Static:
+ case Stale:
+ entry.handlePacketQueuedLocked()
+ fallthrough
+ case Reachable, Static, Delay, Probe:
+ // As per RFC 4861 section 7.3.3:
+ // "Neighbor Unreachability Detection operates in parallel with the sending
+ // of packets to a neighbor. While reasserting a neighbor's reachability,
+ // a node continues sending packets to that neighbor using the cached
+ // link-layer address."
return entry.neigh, nil, nil
-
- case Unknown, Incomplete, Stale, Delay, Probe:
+ case Unknown, Incomplete:
entry.addWakerLocked(w)
if entry.done == nil {
@@ -147,10 +154,8 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.handlePacketQueuedLocked()
return entry.neigh, entry.done, tcpip.ErrWouldBlock
-
case Failed:
return entry.neigh, nil, tcpip.ErrNoLinkAddress
-
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index a0b7da5cd..fcd54ed83 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1500,24 +1500,26 @@ func TestNeighborCacheReplace(t *testing.T) {
}
// Verify the entry exists
- e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
- }
- if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
- }
- if t.Failed() {
- t.FailNow()
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- }
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ {
+ e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ if doneCh != nil {
+ t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ }
}
// Notify of a link address change
@@ -1536,28 +1538,34 @@ func TestNeighborCacheReplace(t *testing.T) {
IsRouter: false,
})
- // Requesting the entry again should start address resolution
+ // Requesting the entry again should start neighbor reachability confirmation.
+ //
+ // Verify the entry's new link address and the new state.
{
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
- clock.Advance(config.DelayFirstProbeTime + typicalLatency)
- select {
- case <-doneCh:
- default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
}
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
- // Verify the entry's new link address
+ // Verify that the neighbor is now reachable.
{
e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
- want = NeighborEntry{
+ want := NeighborEntry{
Addr: entry.Addr,
LocalAddr: entry.LocalAddr,
LinkAddr: updatedLinkAddr,
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 9a72bec79..4d69a4de1 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
// There is no need to log the error here; the NUD implementation may
// assume a working link. A valid link should be the responsibility of
// the NIC/stack.LinkEndpoint.
@@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index a265fff0a..e79abebca 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -227,8 +227,9 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
- id: entryTestNICID,
- linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+
+ id: entryTestNICID,
stack: &Stack{
clock: clock,
nudDisp: &disp,
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 23022292c..8828cc5fe 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -32,14 +32,18 @@ var _ NetworkInterface = (*NIC)(nil)
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
+ LinkEndpoint
+
stack *Stack
id tcpip.NICID
name string
- linkEP LinkEndpoint
context NICContext
- stats NICStats
- neigh *neighborCache
+ stats NICStats
+ neigh *neighborCache
+
+ // The network endpoints themselves may be modified by calling the interface's
+ // methods, but the map reference and entries must be constant.
networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
@@ -88,10 +92,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
+ LinkEndpoint: ep,
+
stack: stack,
id: id,
name: name,
- linkEP: ep,
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
@@ -127,11 +132,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
- nic.linkEP.Attach(nic)
+ nic.LinkEndpoint.Attach(nic)
return nic
}
+func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint {
+ return n.networkEndpoints[proto]
+}
+
// Enabled implements NetworkInterface.
func (n *NIC) Enabled() bool {
return atomic.LoadUint32(&n.enabled) == 1
@@ -211,10 +220,9 @@ func (n *NIC) remove() *tcpip.Error {
for _, ep := range n.networkEndpoints {
ep.Close()
}
- n.networkEndpoints = nil
// Detach from link endpoint, so no packet comes in.
- n.linkEP.Attach(nil)
+ n.LinkEndpoint.Attach(nil)
return nil
}
@@ -234,7 +242,64 @@ func (n *NIC) isPromiscuousMode() bool {
// IsLoopback implements NetworkInterface.
func (n *NIC) IsLoopback() bool {
- return n.linkEP.Capabilities()&CapabilityLoopback != 0
+ return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0
+}
+
+// WritePacket implements NetworkLinkEndpoint.
+func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // As per relevant RFCs, we should queue packets while we wait for link
+ // resolution to complete.
+ //
+ // RFC 1122 section 2.3.2.2 (for IPv4):
+ // The link layer SHOULD save (rather than discard) at least
+ // one (the latest) packet of each set of packets destined to
+ // the same unresolved IP address, and transmit the saved
+ // packet when the address has been resolved.
+ //
+ // RFC 4861 section 5.2 (for IPv6):
+ // Once the IP address of the next-hop node is known, the sender
+ // examines the Neighbor Cache for link-layer information about that
+ // neighbor. If no entry exists, the sender creates one, sets its state
+ // to INCOMPLETE, initiates Address Resolution, and then queues the data
+ // packet pending completion of address resolution.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ r := r.Clone()
+ n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ return nil
+ }
+ return err
+ }
+
+ return n.writePacket(r, gso, protocol, pkt)
+}
+
+func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Size()
+
+ if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
+ return err
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
+}
+
+// WritePackets implements NetworkLinkEndpoint.
+func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution
+ // is being peformed like WritePacket.
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
+ n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Size()
+ }
+
+ n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ return writtenPackets, err
}
// setSpoofing enables or disables address spoofing.
@@ -483,9 +548,9 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ defer r.Release()
r.RemoteLinkAddress = remotelinkAddr
- addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
- addressEndpoint.DecRef()
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -519,7 +584,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// If no local link layer address is provided, assume it was sent
// directly to this NIC.
if local == "" {
- local = n.linkEP.LinkAddress()
+ local = n.LinkEndpoint.LinkAddress()
}
// Are any packet type sockets listening for this network protocol?
@@ -599,11 +664,11 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n := r.nic
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
- addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
addressEndpoint.DecRef()
r.Release()
return
@@ -614,21 +679,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
- // TODO(b/128629022): move this logic to route.WritePacket.
// TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
- // forwarder will release route.
- return
- }
+
+ // pkt may have set its header and may not have enough headroom for
+ // link-layer header for the other link to prepend. Here we create a new
+ // packet to forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- r.Release()
- return
}
- // The link-address resolution finished immediately.
- n.forwardPacket(&r, protocol, pkt)
r.Release()
return
}
@@ -652,43 +717,18 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
- n.linkEP.AddHeader(local, remote, protocol, p)
+ n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
-
- // pkt may have set its header and may not have enough headroom for link-layer
- // header for the other link to prepend. Here we create a new packet to
- // forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- })
-
- // WritePacket takes ownership of fwdPkt, calculate numBytes first.
- numBytes := fwdPkt.Size()
-
- if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return
- }
-
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
- // TODO(gvisor.dev/issue/4365): Let the caller know that the transport
- // protocol is unrecognized.
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
- return TransportPacketHandled
+ return TransportPacketProtocolUnreachable
}
transProto := state.proto
@@ -792,11 +832,6 @@ func (n *NIC) Name() string {
return n.name
}
-// LinkEndpoint implements NetworkInterface.
-func (n *NIC) LinkEndpoint() LinkEndpoint {
- return n.linkEP
-}
-
// nudConfigs gets the NUD configurations for n.
func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
if n.neigh == nil {
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index fdd49b77f..97a96af62 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -33,8 +33,7 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
type testIPv6Endpoint struct {
AddressableEndpointState
- nicID tcpip.NICID
- linkEP LinkEndpoint
+ nic NetworkInterface
protocol *testIPv6Protocol
invalidatedRtr tcpip.Address
@@ -57,12 +56,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 {
// MTU implements NetworkEndpoint.MTU.
func (e *testIPv6Endpoint) MTU() uint32 {
- return e.linkEP.MTU() - header.IPv6MinimumSize
+ return e.nic.MTU() - header.IPv6MinimumSize
}
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
// WritePacket implements NetworkEndpoint.WritePacket.
@@ -134,8 +133,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint implements NetworkProtocol.NewEndpoint.
func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
e := &testIPv6Endpoint{
- nicID: nic.ID(),
- linkEP: nic.LinkEndpoint(),
+ nic: nic,
protocol: p,
}
e.AddressableEndpointState.Init(e)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index a7d9d59fa..105583c49 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
type headerType int
@@ -255,6 +256,20 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
return newPk
}
+// Network returns the network header as a header.Network.
+//
+// Network should only be called when NetworkHeader has been set.
+func (pk *PacketBuffer) Network() header.Network {
+ switch netProto := pk.NetworkProtocolNumber; netProto {
+ case header.IPv4ProtocolNumber:
+ return header.IPv4(pk.NetworkHeader().View())
+ case header.IPv6ProtocolNumber:
+ return header.IPv6(pk.NetworkHeader().View())
+ default:
+ panic(fmt.Sprintf("unknown network protocol number %d", netProto))
+ }
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/pending_packets.go
index 3eff141e6..f838eda8d 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -29,60 +29,60 @@ const (
)
type pendingPacket struct {
- nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
pkt *PacketBuffer
}
-type forwardQueue struct {
+// packetsPendingLinkResolution is a queue of packets pending link resolution.
+//
+// Once link resolution completes successfully, the packets will be written.
+type packetsPendingLinkResolution struct {
sync.Mutex
// The packets to send once the resolver completes.
- packets map[<-chan struct{}][]*pendingPacket
+ packets map[<-chan struct{}][]pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
cancelChans []chan struct{}
}
-func newForwardQueue() *forwardQueue {
- return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+func (f *packetsPendingLinkResolution) init() {
+ f.Lock()
+ defer f.Unlock()
+ f.packets = make(map[<-chan struct{}][]pendingPacket)
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- shouldWait := false
-
+func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
f.Lock()
+ defer f.Unlock()
+
packets, ok := f.packets[ch]
- if !ok {
- shouldWait = true
- }
- for len(packets) == maxPendingPacketsPerResolution {
+ if len(packets) == maxPendingPacketsPerResolution {
p := packets[0]
+ packets[0] = pendingPacket{}
packets = packets[1:]
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
p.route.Release()
}
+
if l := len(packets); l >= maxPendingPacketsPerResolution {
panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
}
- f.packets[ch] = append(packets, &pendingPacket{
- nic: n,
+
+ f.packets[ch] = append(packets, pendingPacket{
route: r,
- proto: protocol,
+ proto: proto,
pkt: pkt,
})
- f.Unlock()
- if !shouldWait {
+ if ok {
return
}
// Wait for the link-address resolution to complete.
- // Start a goroutine with a forwarding-cancel channel so that we can
- // limit the maximum number of goroutines running concurrently.
- cancel := f.newCancelChannel()
+ cancel := f.newCancelChannelLocked()
go func() {
cancelled := false
select {
@@ -92,17 +92,21 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
}
f.Lock()
- packets := f.packets[ch]
+ packets, ok := f.packets[ch]
delete(f.packets, ch)
f.Unlock()
+ if !ok {
+ panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
+ }
+
for _, p := range packets {
if cancelled {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else if _, err := p.route.Resolve(nil); err != nil {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
@@ -112,12 +116,10 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
// newCancelChannel creates a channel that can cancel a pending forwarding
// activity. The oldest channel is closed if the number of open channels would
// exceed maxPendingResolutions.
-func (f *forwardQueue) newCancelChannel() chan struct{} {
- f.Lock()
- defer f.Unlock()
-
+func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
if len(f.cancelChans) == maxPendingResolutions {
ch := f.cancelChans[0]
+ f.cancelChans[0] = nil
f.cancelChans = f.cancelChans[1:]
close(ch)
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index b6f823b54..defb9129b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -208,6 +208,10 @@ const (
// transport layer and callers need not take any further action.
TransportPacketHandled TransportPacketDisposition = iota
+ // TransportPacketProtocolUnreachable indicates that the transport
+ // protocol requested in the packet is not supported.
+ TransportPacketProtocolUnreachable
+
// TransportPacketDestinationPortUnreachable indicates that there weren't any
// listeners interested in the packet and the transport protocol has no means
// to notify the sender.
@@ -322,10 +326,6 @@ const (
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
- // NetworkEndpoint returns the NetworkEndpoint the receiver is associated
- // with.
- NetworkEndpoint() NetworkEndpoint
-
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix
@@ -475,6 +475,8 @@ type NDPEndpoint interface {
// NetworkInterface is a network interface.
type NetworkInterface interface {
+ NetworkLinkEndpoint
+
// ID returns the interface's ID.
ID() tcpip.NICID
@@ -488,9 +490,6 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
-
- // LinkEndpoint returns the link endpoint backing the interface.
- LinkEndpoint() LinkEndpoint
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -663,22 +662,15 @@ const (
CapabilitySoftwareGSO
)
-// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
-// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint. When a link header exists,
-// it sets each PacketBuffer's LinkHeader field before passing it up the
-// stack.
-type LinkEndpoint interface {
+// NetworkLinkEndpoint is a data-link layer that supports sending network
+// layer packets.
+type NetworkLinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
// physical network doesn't exist, the limit is generally 64k, which
// includes the maximum size of an IP packet.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the data link (and
// lower level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -686,7 +678,7 @@ type LinkEndpoint interface {
MaxHeaderLength() uint16
// LinkAddress returns the link address (typically a MAC) of the
- // link endpoint.
+ // endpoint.
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the
@@ -706,6 +698,19 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+}
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each PacketBuffer's LinkHeader field before passing it up the
+// stack.
+type LinkEndpoint interface {
+ NetworkLinkEndpoint
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header. It takes ownership of vv.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 5ade3c832..25f80c1f8 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -72,21 +72,20 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
- linkEP := nic.LinkEndpoint()
r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: linkEP.LinkAddress(),
+ LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
RemoteAddress: remoteAddr,
addressEndpoint: addressEndpoint,
nic: nic,
Loop: loop,
}
- if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = nic.stack
+ r.linkCache = r.nic.stack
}
}
@@ -100,7 +99,7 @@ func (r *Route) NICID() tcpip.NICID {
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength()
+ return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
@@ -116,23 +115,17 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint().Capabilities()
+ return r.nic.LinkEndpoint.Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok {
+ if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
}
-// ResolveWith immediately resolves a route with the specified remote link
-// address.
-func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
-}
-
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
@@ -208,17 +201,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
return tcpip.ErrInvalidEndpointState
}
- // WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Size()
-
- err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- r.nic.stats.Tx.Packets.Increment()
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- }
- return err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
@@ -228,22 +211,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
return 0, tcpip.ErrInvalidEndpointState
}
- // WritePackets takes ownership of pkt, calculate length first.
- numPkts := pkts.Len()
-
- n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
- }
- r.nic.stats.Tx.Packets.IncrementBy(uint64(n))
-
- writtenBytes := 0
- for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
- writtenBytes += pb.Size()
- }
-
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
- return n, err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -253,32 +221,17 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Data.Size()
-
- if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return err
- }
- r.nic.stats.Tx.Packets.Increment()
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- return nil
+ return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.addressEndpoint.NetworkEndpoint().DefaultTTL()
+ return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.addressEndpoint.NetworkEndpoint().MTU()
-}
-
-// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
-// network endpoint.
-func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber()
+ return r.nic.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 57d8e79e0..3a07577c8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -436,9 +436,9 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
- // forwarder holds the packets that wait for their link-address resolutions
- // to complete, and forwards them when each resolution is done.
- forwarder *forwardQueue
+ // linkResQueue holds packets that are waiting for link resolution to
+ // complete.
+ linkResQueue packetsPendingLinkResolution
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
@@ -550,8 +550,8 @@ type TransportEndpointInfo struct {
// incompatible with the receiver.
//
// Preconditon: the parent endpoint mu must be held while calling this method.
-func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
+func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := t.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
netProto = header.IPv4ProtocolNumber
@@ -565,7 +565,7 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
}
- switch len(e.ID.LocalAddress) {
+ switch len(t.ID.LocalAddress) {
case header.IPv4AddressSize:
if len(addr.Addr) == header.IPv6AddressSize {
return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
@@ -577,8 +577,8 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
switch {
- case netProto == e.NetProto:
- case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ case netProto == t.NetProto:
+ case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber:
if v6only {
return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
}
@@ -640,7 +640,6 @@ func New(opts Options) *Stack {
useNeighborCache: opts.UseNeighborCache,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
- forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
@@ -653,6 +652,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
+ s.linkResQueue.init()
// Add specified network protocols.
for _, netProtoFactory := range opts.NetworkProtocols {
@@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
-// GetNICByName gets the NIC specified by name.
-func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+// GetLinkEndpointByName gets the link endpoint specified by name.
+func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint {
s.mu.RLock()
defer s.mu.RUnlock()
for _, nic := range s.nics {
if nic.Name() == name {
- return nic, true
+ return nic.LinkEndpoint
}
}
- return nil, false
+ return nil
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
}
nics[id] = NICInfo{
Name: nic.name,
- LinkAddress: nic.linkEP.LinkAddress(),
+ LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
- MTU: nic.linkEP.MTU(),
+ MTU: nic.LinkEndpoint.MTU(),
Stats: nic.stats,
Context: nic.context,
- ARPHardwareType: nic.linkEP.ARPHardwareType(),
+ ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
}
}
return nics
@@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1539,7 +1539,7 @@ func (s *Stack) Wait() {
s.mu.RLock()
defer s.mu.RUnlock()
for _, n := range s.nics {
- n.linkEP.Wait()
+ n.LinkEndpoint.Wait()
}
}
@@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
// Add our own fake ethernet header.
ethFields := header.EthernetFields{
- SrcAddr: nic.linkEP.LinkAddress(),
+ SrcAddr: nic.LinkEndpoint.LinkAddress(),
DstAddr: dst,
Type: netProto,
}
@@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
vv := buffer.View(fakeHeader).ToVectorisedView()
vv.Append(payload)
- if err := nic.linkEP.WriteRawPacket(vv); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
return err
}
@@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView)
return tcpip.ErrUnknownDevice
}
- if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
return err
}
@@ -1796,7 +1796,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco
return nil, tcpip.ErrUnknownNICID
}
- return nic.networkEndpoints[proto], nil
+ return nic.getNetworkEndpoint(proto), nil
}
// NUDConfigurations gets the per-interface NUD configurations.
@@ -1873,10 +1873,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
if addressEndpoint == nil {
continue
}
-
- ep := addressEndpoint.NetworkEndpoint()
addressEndpoint.DecRef()
- return ep, nil
+ return nic.getNetworkEndpoint(netProto), nil
}
return nil, tcpip.ErrBadAddress
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index aa20f750b..38994cca1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,7 +21,6 @@ import (
"bytes"
"fmt"
"math"
- "net"
"sort"
"testing"
"time"
@@ -35,7 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -77,10 +75,9 @@ type fakeNetworkEndpoint struct {
enabled bool
}
- nicID tcpip.NICID
+ nic stack.NetworkInterface
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- ep stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
@@ -103,7 +100,7 @@ func (f *fakeNetworkEndpoint) Disable() {
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
@@ -135,7 +132,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fakeNetHeaderLen
+ return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -164,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
return nil
}
- return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -216,10 +213,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &fakeNetworkEndpoint{
- nicID: nic.ID(),
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
@@ -2106,7 +2102,7 @@ func TestNICStats(t *testing.T) {
t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
+ if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
@@ -3502,52 +3498,6 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
}
-func TestResolveWith(t *testing.T) {
- const (
- unspecifiedNICID = 0
- nicID = 1
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- })
- ep := channel.New(0, defaultMTU, "")
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- addr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
- PrefixLen: 24,
- },
- }
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
-
- remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4())
- r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
- }
- defer r.Release()
-
- // Should initially require resolution.
- if !r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = false, want = true")
- }
-
- // Manually resolving the route should no longer require resolution.
- r.ResolveWith("\x01")
- if r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = true, want = false")
- }
-}
-
// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
// associated address is removed should not cause a panic.
func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 06c7a3cd3..a4f141253 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -6,6 +6,8 @@ go_test(
name = "integration_test",
size = "small",
srcs = [
+ "forward_test.go",
+ "link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
],
@@ -15,6 +17,8 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/pipe",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
new file mode 100644
index 000000000..ffd38ee1a
--- /dev/null
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -0,0 +1,378 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestForwarding(t *testing.T) {
+ const (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
+ routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1NICID = 1
+ routerNICID1 = 2
+ routerNICID2 = 3
+ host2NICID = 4
+
+ listenPort = 8080
+ )
+
+ host1IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 24,
+ },
+ }
+ routerNIC1IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ routerNIC2IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host2IPv4Addr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host1IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ routerNIC1IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ routerNIC2IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("b::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("b::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+
+ type endpointAndAddresses struct {
+ serverEP tcpip.Endpoint
+ serverAddr tcpip.Address
+ serverReadableCH chan struct{}
+
+ clientEP tcpip.Endpoint
+ clientAddr tcpip.Address
+ clientReadableCH chan struct{}
+ }
+
+ newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) {
+ t.Helper()
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(transProto, netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
+ }
+
+ t.Cleanup(func() {
+ wq.EventUnregister(&we)
+ })
+
+ return ep, ch
+ }
+
+ tests := []struct {
+ name string
+ epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses
+ }{
+ {
+ name: "IPv4 host1 server with host2 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ {
+ name: "IPv6 host2 server with host1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired)
+ routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+
+ if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
+ }
+
+ if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ })
+ routerStack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ tcpip.Route{
+ Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ tcpip.Route{
+ Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ tcpip.Route{
+ Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ })
+
+ epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack)
+ defer epsAndAddrs.serverEP.Close()
+ defer epsAndAddrs.clientEP.Close()
+
+ serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort}
+ if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+ clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr}
+ if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil {
+ t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
+ }
+
+ write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) {
+ t.Helper()
+
+ dataPayload := tcpip.SlicePayload(data)
+ wOpts := tcpip.WriteOptions{To: to}
+ n, ch, err := ep.Write(dataPayload, wOpts)
+ if err == tcpip.ErrNoLinkAddress {
+ // Wait for link resolution to complete.
+ <-ch
+
+ n, _, err = ep.Write(dataPayload, wOpts)
+ } else if err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ }
+
+ if err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ }
+ if want := int64(len(data)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ }
+ }
+
+ data := []byte{1, 2, 3, 4}
+ write(epsAndAddrs.clientEP, data, &serverAddr)
+
+ read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress {
+ t.Helper()
+
+ // Wait for the endpoint to be readable.
+ <-ch
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+
+ if diff := cmp.Diff(v, buffer.View(data)); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != expectedFrom {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ return addr
+ }
+
+ addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr)
+ // Unspecify the NIC since NIC IDs are meaningless across stacks.
+ addr.NIC = 0
+
+ data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ write(epsAndAddrs.serverEP, data, &addr)
+ addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr)
+ if addr.Port != listenPort {
+ t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
new file mode 100644
index 000000000..bf3a6f6ee
--- /dev/null
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -0,0 +1,219 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/pipe"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
+ }
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+)
+
+// TestPing tests that two hosts can ping eachother when link resolution is
+// enabled.
+func TestPing(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+
+ // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
+ // request/reply packets.
+ icmpDataOffset = 8
+ )
+
+ tests := []struct {
+ name string
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ icmpBuf func(*testing.T) buffer.View
+ }{
+ {
+ name: "IPv4 Ping",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ icmpBuf: func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv4Echo)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ },
+ },
+ {
+ name: "IPv6 Ping",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ icmpBuf: func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv6EchoRequest)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ }
+
+ host1Stack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+
+ host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+
+ if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ tcpip.Route{
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ })
+
+ var wq waiter.Queue
+ we, waiterCH := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
+
+ // The first write should trigger link resolution.
+ icmpBuf := test.icmpBuf(t)
+ wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
+ if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress)
+ } else {
+ // Wait for link resolution to complete.
+ <-ch
+ }
+ if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ t.Fatalf("ep.Write(_, _): %s", err)
+ } else if want := int64(len(icmpBuf)); n != want {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want)
+ }
+
+ // Wait for the endpoint to be readable.
+ <-waiterCH
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+ if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != test.remoteAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 72d86b5ab..4f2ca7f54 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -203,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
}
- src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View())
+ src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
if src != expectedSrc {
t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 6891fd245..0aaef495d 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -804,7 +804,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
pkt.Owner = owner
pkt.EgressRoute = r
pkt.GSOOptions = gso
- pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
+ pkt.NetworkProtocolNumber = r.NetProto
data.ReadToVV(&pkt.Data, packetSize)
buildTCPHdr(r, tf, pkt, gso)
tf.seq = tf.seq.Add(seqnum.Size(packetSize))
@@ -1219,12 +1219,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
return true, nil
}
- // Increase counter if after processing the segment we would potentially
- // advertise a zero window.
- if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above {
- e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- }
-
// Now check if the received segment has caused us to transition
// to a CLOSED state, if yes then terminate processing and do
// not invoke the sender.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7ad894840..3bcd3923a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -248,6 +248,11 @@ type ReceiveErrors struct {
// ZeroRcvWindowState is the number of times we advertised
// a zero receive window when rcvList is full.
ZeroRcvWindowState tcpip.StatCounter
+
+ // WantZeroWindow is the number of times we wanted to advertise a
+ // zero receive window but couldn't because it would have caused
+ // the receive window's right edge to shrink.
+ WantZeroRcvWindow tcpip.StatCounter
}
// SendErrors collect segment send errors within the transport layer.
@@ -1162,7 +1167,7 @@ func (e *endpoint) cleanupLocked() {
// wndFromSpace returns the window that we can advertise based on the available
// receive buffer space.
func wndFromSpace(space int) int {
- return space / (1 << rcvAdvWndScale)
+ return space >> rcvAdvWndScale
}
// initialReceiveWindow returns the initial receive window to advertise in the
@@ -1518,6 +1523,38 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
+// selectWindowLocked returns the new window without checking for shrinking or scaling
+// applied.
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) {
+ wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked())
+ maxWindow := wndFromSpace(e.rcvBufSize)
+ wndFromUsedBytes := maxWindow - e.rcvBufUsed
+
+ // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in
+ // cases where we receive a lot of small segments the segment overhead is a
+ // lot higher and we can run out socket buffer space before we can fill the
+ // previous window we advertised. In cases where we receive MSS sized or close
+ // MSS sized segments we will probably run out of window space before we
+ // exhaust receive buffer.
+ newWnd := wndFromAvailable
+ if newWnd > wndFromUsedBytes {
+ newWnd = wndFromUsedBytes
+ }
+ if newWnd < 0 {
+ newWnd = 0
+ }
+ return seqnum.Size(newWnd)
+}
+
+// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu.
+func (e *endpoint) selectWindow() (wnd seqnum.Size) {
+ e.rcvListMu.Lock()
+ wnd = e.selectWindowLocked()
+ e.rcvListMu.Unlock()
+ return wnd
+}
+
// windowCrossedACKThresholdLocked checks if the receive window to be announced
// would be under aMSS or under the window derived from half receive buffer,
// whichever smaller. This is useful as a receive side silly window syndrome
@@ -1534,7 +1571,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// Precondition: e.mu and e.rcvListMu must be held.
func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
- newAvail := wndFromSpace(e.receiveBufferAvailableLocked())
+ newAvail := int(e.selectWindowLocked())
oldAvail := newAvail - deltaBefore
if oldAvail < 0 {
oldAvail = 0
@@ -2099,7 +2136,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
case *tcpip.OriginalDestinationOption:
e.LockUser()
ipt := e.stack.IPTables()
- addr, port, err := ipt.OriginalDst(e.ID)
+ addr, port, err := ipt.OriginalDst(e.ID, e.NetProto)
e.UnlockUser()
if err != nil {
return err
@@ -3013,6 +3050,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
EndSequence: rc.endSequence,
FACK: rc.fack,
RTT: rc.rtt,
+ Reord: rc.reorderSeen,
}
return s
}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index d969ca23a..d312b1b8b 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -29,26 +29,36 @@ import (
//
// +stateify savable
type rackControl struct {
- // xmitTime is the latest transmission timestamp of rackControl.seg.
- xmitTime time.Time `state:".(unixTime)"`
-
// endSequence is the ending TCP sequence number of rackControl.seg.
endSequence seqnum.Value
+ // dsack indicates if the connection has seen a DSACK.
+ dsack bool
+
// fack is the highest selectively or cumulatively acknowledged
// sequence.
fack seqnum.Value
+ // minRTT is the estimated minimum RTT of the connection.
+ minRTT time.Duration
+
// rtt is the RTT of the most recently delivered packet on the
// connection (either cumulatively acknowledged or selectively
// acknowledged) that was not marked invalid as a possible spurious
// retransmission.
rtt time.Duration
+
+ // reorderSeen indicates if reordering has been detected on this
+ // connection.
+ reorderSeen bool
+
+ // xmitTime is the latest transmission timestamp of rackControl.seg.
+ xmitTime time.Time `state:".(unixTime)"`
}
-// Update will update the RACK related fields when an ACK has been received.
+// update will update the RACK related fields when an ACK has been received.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
-func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) {
+func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) {
rtt := time.Now().Sub(seg.xmitTime)
// If the ACK is for a retransmitted packet, do not update if it is a
@@ -65,12 +75,21 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration,
return
}
}
- if rtt < srtt {
+ if rtt < rc.minRTT {
return
}
}
rc.rtt = rtt
+
+ // The sender can either track a simple global minimum of all RTT
+ // measurements from the connection, or a windowed min-filtered value
+ // of recent RTT measurements. This implementation keeps track of the
+ // simple global minimum of all RTTs for the connection.
+ if rtt < rc.minRTT || rc.minRTT == 0 {
+ rc.minRTT = rtt
+ }
+
// Update rc.xmitTime and rc.endSequence to the transmit time and
// ending sequence number of the packet which has been acknowledged
// most recently.
@@ -80,3 +99,26 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration,
rc.endSequence = endSeq
}
}
+
+// detectReorder detects if packet reordering has been observed.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// * Step 3: Detect data segment reordering.
+// To detect reordering, the sender looks for original data segments being
+// delivered out of order. To detect such cases, the sender tracks the
+// highest sequence selectively or cumulatively acknowledged in the RACK.fack
+// variable. The name "fack" stands for the most "Forward ACK" (this term is
+// adopted from [FACK]). If a never retransmitted segment that's below
+// RACK.fack is (selectively or cumulatively) acknowledged, it has been
+// delivered out of order. The sender sets RACK.reord to TRUE if such segment
+// is identified.
+func (rc *rackControl) detectReorder(seg *segment) {
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.fack.LessThan(endSeq) {
+ rc.fack = endSeq
+ return
+ }
+
+ if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 {
+ rc.reorderSeen = true
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 48bf196d8..8e0b7c843 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -43,6 +43,9 @@ type receiver struct {
// rcvWnd is the non-scaled receive window last advertised to the peer.
rcvWnd seqnum.Size
+ // rcvWUP is the rcvNxt value at the last window update sent.
+ rcvWUP seqnum.Value
+
rcvWndScale uint8
closed bool
@@ -64,6 +67,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWnd: rcvWnd,
+ rcvWUP: irs + 1,
rcvWndScale: rcvWndScale,
lastRcvdAckTime: time.Now(),
}
@@ -84,34 +88,54 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
}
+// currentWindow returns the available space in the window that was advertised
+// last to our peer.
+func (r *receiver) currentWindow() (curWnd seqnum.Size) {
+ endOfWnd := r.rcvWUP.Add(r.rcvWnd)
+ if endOfWnd.LessThan(r.rcvNxt) {
+ // return 0 if r.rcvNxt is past the end of the previously advertised window.
+ // This can happen because we accept a large segment completely even if
+ // accepting it causes it to partially exceed the advertised window.
+ return 0
+ }
+ return r.rcvNxt.Size(endOfWnd)
+}
+
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- avail := wndFromSpace(r.ep.receiveBufferAvailable())
- if avail == 0 {
- // We have no space available to accept any data, move to zero window
- // state.
- r.rcvWnd = 0
- return r.rcvNxt, 0
- }
-
- acc := r.rcvNxt.Add(seqnum.Size(avail))
- newWnd := r.rcvNxt.Size(acc)
- curWnd := r.rcvNxt.Size(r.rcvAcc)
-
+ newWnd := r.ep.selectWindow()
+ curWnd := r.currentWindow()
// Update rcvAcc only if new window is > previously advertised window. We
// should never shrink the acceptable sequence space once it has been
// advertised the peer. If we shrink the acceptable sequence space then we
// would end up dropping bytes that might already be in flight.
- if newWnd > curWnd {
- r.rcvAcc = r.rcvNxt.Add(newWnd)
+ // ==================================================== sequence space.
+ // ^ ^ ^ ^
+ // rcvWUP rcvNxt rcvAcc new rcvAcc
+ // <=====curWnd ===>
+ // <========= newWnd > curWnd ========= >
+ if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) {
+ // If the new window moves the right edge, then update rcvAcc.
+ r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd))
} else {
+ if newWnd == 0 {
+ // newWnd is zero but we can't advertise a zero as it would cause window
+ // to shrink so just increment a metric to record this event.
+ r.ep.stats.ReceiveErrors.WantZeroRcvWindow.Increment()
+ }
newWnd = curWnd
}
// Stash away the non-scaled receive window as we use it for measuring
// receiver's estimated RTT.
r.rcvWnd = newWnd
- return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
+ r.rcvWUP = r.rcvNxt
+ scaledWnd := r.rcvWnd >> r.rcvWndScale
+ if scaledWnd == 0 {
+ // Increment a metric if we are advertising an actual zero window.
+ r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
+ }
+ return r.rcvNxt, scaledWnd
}
// nonZeroWindow is called when the receive window grows from zero to nonzero;
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 13acaf753..1f9c5cf50 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -71,6 +71,9 @@ type segment struct {
// xmitTime is the last transmit time of this segment.
xmitTime time.Time `state:".(unixTime)"`
xmitCount uint32
+
+ // acked indicates if the segment has already been SACKed.
+ acked bool
}
func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index c55589c45..6fa8d63cd 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -17,6 +17,7 @@ package tcp
import (
"fmt"
"math"
+ "sort"
"sync/atomic"
"time"
@@ -263,6 +264,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
highRxt: iss,
rescueRxt: iss,
},
+ rc: rackControl{
+ fack: iss,
+ },
gso: ep.gso != nil,
}
@@ -1274,6 +1278,39 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
return true
}
+// Iterate the writeList and update RACK for each segment which is newly acked
+// either cumulatively or selectively. Loop through the segments which are
+// sacked, and update the RACK related variables and check for reordering.
+//
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// steps 2 and 3.
+func (s *sender) walkSACK(rcvdSeg *segment) {
+ // Sort the SACK blocks. The first block is the most recent unacked
+ // block. The following blocks can be in arbitrary order.
+ sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
+ copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks)
+ sort.Slice(sackBlocks, func(i, j int) bool {
+ return sackBlocks[j].Start.LessThan(sackBlocks[i].Start)
+ })
+
+ seg := s.writeList.Front()
+ for _, sb := range sackBlocks {
+ // This check excludes DSACK blocks.
+ if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) {
+ continue
+ }
+
+ for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 {
+ if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked {
+ s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.detectReorder(seg)
+ seg.acked = true
+ }
+ seg = seg.Next()
+ }
+ }
+}
+
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
@@ -1308,6 +1345,21 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
rcvdSeg.hasNewSACKInfo = true
}
}
+
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08
+ // section-7.2
+ // * Step 2: Update RACK stats.
+ // If the ACK is not ignored as invalid, update the RACK.rtt
+ // to be the RTT sample calculated using this ACK, and
+ // continue. If this ACK or SACK was for the most recently
+ // sent packet, then record the RACK.xmit_ts timestamp and
+ // RACK.end_seq sequence implied by this ACK.
+ // * Step 3: Detect packet reordering.
+ // If the ACK selectively or cumulatively acknowledges an
+ // unacknowledged and also never retransmitted sequence below
+ // RACK.fack, then the corresponding packet has been
+ // reordered and RACK.reord is set to TRUE.
+ s.walkSACK(rcvdSeg)
s.SetPipe()
}
@@ -1365,9 +1417,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
ackLeft := acked
originalOutstanding := s.outstanding
- s.rtt.Lock()
- srtt := s.rtt.srtt
- s.rtt.Unlock()
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -1388,13 +1437,14 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// Update the RACK fields if SACK is enabled.
- if s.ep.sackPermitted {
- s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset)
+ if s.ep.sackPermitted && !seg.acked {
+ s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.detectReorder(seg)
}
s.writeList.Remove(seg)
- // if SACK is enabled then Only reduce outstanding if
+ // If SACK is enabled then Only reduce outstanding if
// the segment was not previously SACKED as these have
// already been accounted for in SetPipe().
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index e03f101e8..d3f92b48c 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -21,17 +21,20 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
)
+const (
+ maxPayload = 10
+ tsOptionSize = 12
+ maxTCPOptionSize = 40
+)
+
// TestRACKUpdate tests the RACK related fields are updated when an ACK is
// received on a SACK enabled connection.
func TestRACKUpdate(t *testing.T) {
- const maxPayload = 10
- const tsOptionSize = 12
- const maxTCPOptionSize = 40
-
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
defer c.Cleanup()
@@ -49,7 +52,7 @@ func TestRACKUpdate(t *testing.T) {
}
if state.Sender.RACKState.RTT == 0 {
- t.Fatalf("RACK RTT failed to update when an ACK is received")
+ t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0")
}
})
setStackSACKPermitted(t, c, true)
@@ -69,6 +72,66 @@ func TestRACKUpdate(t *testing.T) {
bytesRead := 0
c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
bytesRead += maxPayload
- c.SendAck(790, bytesRead)
+ c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead)
time.Sleep(200 * time.Millisecond)
}
+
+// TestRACKDetectReorder tests that RACK detects packet reordering.
+func TestRACKDetectReorder(t *testing.T) {
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ const ackNum = 2
+
+ var n int
+ ch := make(chan struct{})
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ gotSeq := state.Sender.RACKState.FACK
+ wantSeq := state.Sender.SndNxt
+ // FACK should be updated to the highest ending sequence number of the
+ // segment acknowledged most recently.
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
+ t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ n++
+ if n < ackNum {
+ if state.Sender.RACKState.Reord {
+ t.Fatalf("RACK reorder detected when there is no reordering")
+ }
+ return
+ }
+
+ if state.Sender.RACKState.Reord == false {
+ t.Fatalf("RACK reorder detection failed")
+ }
+ close(ch)
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+ data := buffer.NewView(ackNum * maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ for i := 0; i < ackNum; i++ {
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ }
+
+ start := c.IRS.Add(maxPayload + 1)
+ end := start.Add(maxPayload)
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}})
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ <-ch
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 5b504d0d1..a7149efd0 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -6264,14 +6264,27 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
rawEP.NextSeqNum--
rawEP.SendPacketWithTS(nil, tsVal)
rawEP.NextSeqNum++
+
if i == 0 {
// In the first iteration the receiver based RTT is not
// yet known as a result the moderation code should not
// increase the advertised window.
rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd))
} else {
- pkt := c.GetPacket()
- curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale
+ // Read loop above could generate an ACK if the window had dropped to
+ // zero and then read had opened it up.
+ lastACK := c.GetPacket()
+ // Discard any intermediate ACKs and only check the last ACK we get in a
+ // short time period of few ms.
+ for {
+ time.Sleep(1 * time.Millisecond)
+ pkt := c.GetPacketNonBlocking()
+ if pkt == nil {
+ break
+ }
+ lastACK = pkt
+ }
+ curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale
// If thew new current window is close maxReceiveBufferSize then terminate
// the loop. This can happen before all iterations are done due to timing
// differences when running the test.
@@ -7328,7 +7341,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Write chunks of ~30000 bytes. It's important that two
// payloads make it equal or longer than MSS.
- remain := rcvBuf * 2
+ remain := rcvBuf
sent := 0
data := make([]byte, defaultMTU/2)
@@ -7343,7 +7356,6 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
})
sent += len(data)
remain -= len(data)
-
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index faf51ef95..4d7847142 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -68,9 +68,9 @@ const (
// V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
- // testInitialSequenceNumber is the initial sequence number sent in packets that
+ // TestInitialSequenceNumber is the initial sequence number sent in packets that
// are sent in response to a SYN or in the initial SYN sent to the stack.
- testInitialSequenceNumber = 789
+ TestInitialSequenceNumber = 789
)
// StackAddrWithPrefix is StackAddr with its associated prefix length.
@@ -505,7 +505,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.TCP(
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -532,7 +532,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.TCP(
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
- checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+ checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
@@ -912,7 +912,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
- iss := seqnum.Value(testInitialSequenceNumber)
+ iss := seqnum.Value(TestInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
@@ -1084,7 +1084,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
offset += paddingToAdd
// Send a SYN request.
- iss := seqnum.Value(testInitialSequenceNumber)
+ iss := seqnum.Value(TestInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: TestPort,
DstPort: StackPort,
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4e14a5fc5..b4604ba35 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1432,7 +1432,9 @@ func TestNoChecksum(t *testing.T) {
var _ stack.NetworkInterface = (*testInterface)(nil)
-type testInterface struct{}
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
func (*testInterface) ID() tcpip.NICID {
return 0
@@ -1450,10 +1452,6 @@ func (*testInterface) Enabled() bool {
return true
}
-func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
- return nil
-}
-
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1780,16 +1778,26 @@ func TestV4UnknownDestination(t *testing.T) {
checker.ICMPv4Type(header.ICMPv4DstUnreachable),
checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ // We need to compare the included data part of the UDP packet that is in
+ // the ICMP packet with the matching original data.
icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
wantLen := len(payload)
if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ // To work out the data size we need to simulate what the sender would
+ // have done. The wanted size is the total available minus the sum of
+ // the headers in the UDP AND ICMP packets, given that we know the test
+ // had only a minimal IP header but the ICMP sender will have allowed
+ // for a maximally sized packet header.
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
+
}
- // In case of large payloads the IP packet may be truncated. Update
+ // In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ // Add back the two headers within the payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload())
if got, want := len(origDgram.Payload()), wantLen; got != want {
@@ -2015,7 +2023,8 @@ func TestPayloadModifiedV4(t *testing.T) {
payload := newPayload()
h := unicastV4.header4Tuple(incoming)
buf := c.buildV4Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
buf[len(buf)-1]++
c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
@@ -2045,7 +2054,8 @@ func TestPayloadModifiedV6(t *testing.T) {
payload := newPayload()
h := unicastV6.header4Tuple(incoming)
buf := c.buildV6Packet(payload, &h)
- // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ // Modify the payload so that the checksum value in the UDP header will be
+ // incorrect.
buf[len(buf)-1]++
c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 06fb823f6..49ab87c58 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -270,7 +270,7 @@ func RandomID(prefix string) string {
// same name, sometimes between test runs the socket does not get cleaned up
// quickly enough, causing container creation to fail.
func RandomContainerID() string {
- return RandomID("test-container-")
+ return RandomID("test-container")
}
// Copy copies file from src to dst.