Diffstat (limited to 'pkg')
-rw-r--r--  pkg/bpf/decoder.go | 2
-rw-r--r--  pkg/merkletree/merkletree.go | 14
-rw-r--r--  pkg/merkletree/merkletree_test.go | 29
-rw-r--r--  pkg/refsvfs2/refs_map.go | 29
-rw-r--r--  pkg/refsvfs2/refs_template.go | 38
-rw-r--r--  pkg/sentry/fs/gofer/path.go | 3
-rw-r--r--  pkg/sentry/fs/proc/sys_net.go | 4
-rw-r--r--  pkg/sentry/fs/tmpfs/inode_file.go | 6
-rw-r--r--  pkg/sentry/fs/tmpfs/tmpfs.go | 2
-rw-r--r--  pkg/sentry/fsimpl/devpts/devpts.go | 2
-rw-r--r--  pkg/sentry/fsimpl/fuse/dev_test.go | 2
-rw-r--r--  pkg/sentry/fsimpl/fuse/fusefs.go | 8
-rw-r--r--  pkg/sentry/fsimpl/fuse/utils_test.go | 1
-rw-r--r--  pkg/sentry/fsimpl/gofer/directory.go | 4
-rw-r--r--  pkg/sentry/fsimpl/gofer/filesystem.go | 26
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer.go | 214
-rw-r--r--  pkg/sentry/fsimpl/gofer/regular_file.go | 20
-rw-r--r--  pkg/sentry/fsimpl/gofer/save_restore.go | 4
-rw-r--r--  pkg/sentry/fsimpl/host/host.go | 2
-rw-r--r--  pkg/sentry/fsimpl/host/socket.go | 6
-rw-r--r--  pkg/sentry/fsimpl/kernfs/filesystem.go | 54
-rw-r--r--  pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 18
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs.go | 16
-rw-r--r--  pkg/sentry/fsimpl/kernfs/kernfs_test.go | 4
-rw-r--r--  pkg/sentry/fsimpl/overlay/copy_up.go | 2
-rw-r--r--  pkg/sentry/fsimpl/overlay/filesystem.go | 133
-rw-r--r--  pkg/sentry/fsimpl/overlay/overlay.go | 23
-rw-r--r--  pkg/sentry/fsimpl/pipefs/pipefs.go | 2
-rw-r--r--  pkg/sentry/fsimpl/proc/subtasks.go | 2
-rw-r--r--  pkg/sentry/fsimpl/proc/task.go | 2
-rw-r--r--  pkg/sentry/fsimpl/proc/task_fds.go | 4
-rw-r--r--  pkg/sentry/fsimpl/proc/tasks.go | 2
-rw-r--r--  pkg/sentry/fsimpl/sys/sys.go | 2
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/named_pipe.go | 3
-rw-r--r--  pkg/sentry/fsimpl/tmpfs/tmpfs.go | 2
-rw-r--r--  pkg/sentry/fsimpl/verity/filesystem.go | 6
-rw-r--r--  pkg/sentry/fsimpl/verity/verity.go | 50
-rw-r--r--  pkg/sentry/fsimpl/verity/verity_test.go | 869
-rw-r--r--  pkg/sentry/hostfd/BUILD | 2
-rw-r--r--  pkg/sentry/hostfd/hostfd_linux.go | 18
-rw-r--r--  pkg/sentry/hostfd/hostfd_unsafe.go | 9
-rw-r--r--  pkg/sentry/kernel/fd_table_unsafe.go | 2
-rw-r--r--  pkg/sentry/kernel/fs_context.go | 6
-rw-r--r--  pkg/sentry/kernel/ipc_namespace.go | 2
-rw-r--r--  pkg/sentry/kernel/pipe/node_test.go | 8
-rw-r--r--  pkg/sentry/kernel/pipe/pipe.go | 50
-rw-r--r--  pkg/sentry/kernel/pipe/pipe_test.go | 11
-rw-r--r--  pkg/sentry/kernel/pipe/vfs.go | 4
-rw-r--r--  pkg/sentry/kernel/ptrace.go | 4
-rw-r--r--  pkg/sentry/kernel/semaphore/semaphore.go | 46
-rw-r--r--  pkg/sentry/kernel/sessions.go | 38
-rw-r--r--  pkg/sentry/kernel/shm/shm.go | 2
-rw-r--r--  pkg/sentry/kernel/task_clone.go | 2
-rw-r--r--  pkg/sentry/kernel/task_usermem.go | 95
-rw-r--r--  pkg/sentry/kernel/vdso.go | 2
-rw-r--r--  pkg/sentry/mm/aio_context.go | 2
-rw-r--r--  pkg/sentry/mm/special_mappable.go | 2
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go | 12
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64.go | 17
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go | 17
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_unsafe.go | 7
-rw-r--r--  pkg/sentry/platform/kvm/kvm.go | 3
-rw-r--r--  pkg/sentry/platform/kvm/kvm_const_arm64.go | 3
-rw-r--r--  pkg/sentry/platform/kvm/machine.go | 15
-rw-r--r--  pkg/sentry/platform/kvm/machine_amd64.go | 31
-rw-r--r--  pkg/sentry/platform/kvm/machine_arm64_unsafe.go | 8
-rw-r--r--  pkg/sentry/platform/ring0/aarch64.go | 67
-rw-r--r--  pkg/sentry/platform/ring0/defs.go | 3
-rw-r--r--  pkg/sentry/platform/ring0/defs_amd64.go | 10
-rw-r--r--  pkg/sentry/platform/ring0/defs_arm64.go | 7
-rw-r--r--  pkg/sentry/platform/ring0/entry_arm64.s | 31
-rw-r--r--  pkg/sentry/platform/ring0/kernel.go | 6
-rw-r--r--  pkg/sentry/platform/ring0/kernel_amd64.go | 5
-rw-r--r--  pkg/sentry/platform/ring0/kernel_arm64.go | 4
-rw-r--r--  pkg/sentry/platform/ring0/offsets_arm64.go | 45
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables.go | 84
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go | 10
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go | 21
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go | 13
-rw-r--r--  pkg/sentry/platform/ring0/pagetables/walker_arm64.go | 2
-rw-r--r--  pkg/sentry/socket/netfilter/extensions.go | 63
-rw-r--r--  pkg/sentry/socket/netfilter/ipv4.go | 4
-rw-r--r--  pkg/sentry/socket/netfilter/ipv6.go | 4
-rw-r--r--  pkg/sentry/socket/netfilter/netfilter.go | 52
-rw-r--r--  pkg/sentry/socket/netfilter/targets.go | 217
-rw-r--r--  pkg/sentry/socket/netlink/route/protocol.go | 2
-rw-r--r--  pkg/sentry/socket/unix/transport/connectioned.go | 8
-rw-r--r--  pkg/sentry/socket/unix/transport/connectionless.go | 2
-rw-r--r--  pkg/sentry/socket/unix/unix.go | 2
-rw-r--r--  pkg/sentry/socket/unix/unix_vfs2.go | 2
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go | 4
-rw-r--r--  pkg/sentry/syscalls/linux/sys_pipe.go | 2
-rw-r--r--  pkg/sentry/syscalls/linux/sys_sem.go | 32
-rw-r--r--  pkg/sentry/syscalls/linux/sys_splice.go | 30
-rw-r--r--  pkg/sentry/syscalls/linux/vfs2/splice.go | 23
-rw-r--r--  pkg/sentry/vfs/file_description.go | 2
-rw-r--r--  pkg/sentry/vfs/filesystem.go | 2
-rw-r--r--  pkg/sentry/vfs/inotify.go | 2
-rw-r--r--  pkg/sentry/vfs/mount.go | 13
-rw-r--r--  pkg/sentry/vfs/mount_test.go | 35
-rw-r--r--  pkg/shim/runsc/BUILD | 1
-rw-r--r--  pkg/shim/runsc/runsc.go | 12
-rw-r--r--  pkg/tcpip/header/ipv6.go | 6
-rw-r--r--  pkg/tcpip/link/sniffer/sniffer.go | 7
-rw-r--r--  pkg/tcpip/link/tun/device.go | 2
-rw-r--r--  pkg/tcpip/network/arp/arp.go | 5
-rw-r--r--  pkg/tcpip/network/ip_test.go | 107
-rw-r--r--  pkg/tcpip/network/ipv4/icmp.go | 56
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4.go | 116
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4_test.go | 14
-rw-r--r--  pkg/tcpip/network/ipv6/icmp.go | 61
-rw-r--r--  pkg/tcpip/network/ipv6/icmp_test.go | 11
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6.go | 124
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6_test.go | 14
-rw-r--r--  pkg/tcpip/network/ipv6/ndp_test.go | 10
-rw-r--r--  pkg/tcpip/stack/addressable_endpoint_state.go | 10
-rw-r--r--  pkg/tcpip/stack/conntrack.go | 14
-rw-r--r--  pkg/tcpip/stack/forwarding_test.go | 4
-rw-r--r--  pkg/tcpip/stack/iptables.go | 75
-rw-r--r--  pkg/tcpip/stack/iptables_targets.go | 102
-rw-r--r--  pkg/tcpip/stack/iptables_types.go | 41
-rw-r--r--  pkg/tcpip/stack/neighbor_entry.go | 21
-rw-r--r--  pkg/tcpip/stack/neighbor_entry_test.go | 1256
-rw-r--r--  pkg/tcpip/stack/nic.go | 51
-rw-r--r--  pkg/tcpip/stack/nic_test.go | 3
-rw-r--r--  pkg/tcpip/stack/packet_buffer.go | 60
-rw-r--r--  pkg/tcpip/stack/pending_packets.go | 2
-rw-r--r--  pkg/tcpip/stack/registration.go | 36
-rw-r--r--  pkg/tcpip/stack/route.go | 288
-rw-r--r--  pkg/tcpip/stack/stack.go | 308
-rw-r--r--  pkg/tcpip/stack/stack_test.go | 484
-rw-r--r--  pkg/tcpip/stack/transport_demuxer.go | 63
-rw-r--r--  pkg/tcpip/stack/transport_test.go | 33
-rw-r--r--  pkg/tcpip/tests/integration/BUILD | 1
-rw-r--r--  pkg/tcpip/tests/integration/forward_test.go | 46
-rw-r--r--  pkg/tcpip/tests/integration/link_resolution_test.go | 48
-rw-r--r--  pkg/tcpip/tests/integration/loopback_test.go | 2
-rw-r--r--  pkg/tcpip/tests/integration/multicast_broadcast_test.go | 24
-rw-r--r--  pkg/tcpip/tests/integration/route_test.go | 388
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go | 4
-rw-r--r--  pkg/tcpip/transport/icmp/protocol.go | 2
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go | 14
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go | 84
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go | 30
-rw-r--r--  pkg/tcpip/transport/tcp/dispatcher.go | 7
-rw-r--r--  pkg/tcpip/transport/tcp/dual_stack_test.go | 19
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go | 33
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint_state.go | 4
-rw-r--r--  pkg/tcpip/transport/tcp/forwarder.go | 12
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go | 31
-rw-r--r--  pkg/tcpip/transport/tcp/segment.go | 47
-rw-r--r--  pkg/tcpip/transport/tcp/snd.go | 4
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/context.go | 95
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go | 31
-rw-r--r--  pkg/tcpip/transport/udp/forwarder.go | 24
-rw-r--r--  pkg/tcpip/transport/udp/protocol.go | 8
-rw-r--r--  pkg/unet/unet_test.go | 157
157 files changed, 4897 insertions, 2452 deletions
diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go
index 069d0395d..6d1e65cb1 100644
--- a/pkg/bpf/decoder.go
+++ b/pkg/bpf/decoder.go
@@ -109,7 +109,7 @@ func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error {
case B:
w.WriteString("1")
default:
- return fmt.Errorf("Invalid BPF LD size: %v", inst)
+ return fmt.Errorf("invalid BPF LD size: %v", inst)
}
return nil
}
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 18457d287..e0a9e56c5 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -147,6 +147,7 @@ func (layout Layout) blockOffset(level int, index int64) int64 {
// metadata.
type VerityDescriptor struct {
Name string
+ FileSize int64
Mode uint32
UID uint32
GID uint32
@@ -154,7 +155,7 @@ type VerityDescriptor struct {
}
func (d *VerityDescriptor) String() string {
- return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash)
+ return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, d.RootHash)
}
// verify generates a hash from d, and compares it with expected.
@@ -289,6 +290,7 @@ func Generate(params *GenerateParams) ([]byte, error) {
}
descriptor := VerityDescriptor{
Name: params.Name,
+ FileSize: params.Size,
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
@@ -342,6 +344,7 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error {
}
descriptor := VerityDescriptor{
Name: params.Name,
+ FileSize: params.Size,
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
@@ -401,10 +404,11 @@ func Verify(params *VerifyParams) (int64, error) {
}
}
descriptor := VerityDescriptor{
- Name: params.Name,
- Mode: params.Mode,
- UID: params.UID,
- GID: params.GID,
+ Name: params.Name,
+ FileSize: params.Size,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
}
if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil {
return 0, err
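Note: the change above threads the file size into the verity descriptor so that a size change (truncation or extension) alters the metadata hash. A minimal, self-contained Go sketch of that idea follows; it is illustrative only and does not reproduce the package's actual descriptor encoding or its SHA-256/SHA-512 selection.

package main

import (
	"crypto/sha256"
	"fmt"
)

// descriptor mirrors the fields hashed for verity metadata; FileSize is the
// newly added field, so changing the size changes the resulting hash.
type descriptor struct {
	Name     string
	FileSize int64
	Mode     uint32
	UID, GID uint32
	RootHash []byte
}

func (d descriptor) digest() [32]byte {
	// Hash a textual encoding in the same field order as
	// VerityDescriptor.String(); the real package may encode differently.
	s := fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v",
		d.Name, d.FileSize, d.Mode, d.UID, d.GID, d.RootHash)
	return sha256.Sum256([]byte(s))
}

func main() {
	d := descriptor{Name: "file", FileSize: 4096, Mode: 0644}
	h1 := d.digest()
	d.FileSize-- // analogous to the modifySize test case added below
	h2 := d.digest()
	fmt.Println(h1 == h2) // false: a size change fails verification
}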
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index 0782ca3e7..405204d94 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -185,42 +185,42 @@ func TestGenerate(t *testing.T) {
{
data: bytes.Repeat([]byte{0}, usermem.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
- expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
+ expectedHash: []byte{39, 30, 12, 152, 185, 58, 32, 84, 218, 79, 74, 113, 104, 219, 230, 234, 25, 126, 147, 36, 212, 44, 76, 74, 25, 93, 228, 41, 243, 143, 59, 147},
},
{
data: bytes.Repeat([]byte{0}, usermem.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
- expectedHash: []byte{14, 27, 126, 158, 9, 94, 163, 51, 243, 162, 82, 167, 183, 127, 93, 121, 221, 23, 184, 59, 104, 166, 111, 49, 161, 195, 229, 111, 121, 201, 233, 68, 10, 154, 78, 142, 154, 236, 170, 156, 110, 167, 15, 144, 155, 97, 241, 235, 202, 233, 246, 217, 138, 88, 152, 179, 238, 46, 247, 185, 125, 20, 101, 201},
+ expectedHash: []byte{184, 76, 172, 204, 17, 136, 127, 75, 224, 42, 251, 181, 98, 149, 1, 44, 58, 148, 20, 187, 30, 174, 73, 87, 166, 9, 109, 169, 42, 96, 87, 202, 59, 82, 174, 80, 51, 95, 101, 100, 6, 246, 56, 120, 27, 166, 29, 59, 67, 115, 227, 121, 241, 177, 63, 238, 82, 157, 43, 107, 174, 180, 44, 84},
},
{
data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
- expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
+ expectedHash: []byte{213, 221, 252, 9, 241, 250, 186, 1, 242, 132, 83, 77, 180, 207, 119, 48, 206, 113, 37, 253, 252, 159, 71, 70, 3, 53, 42, 244, 230, 244, 173, 143},
},
{
data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
- expectedHash: []byte{55, 204, 240, 1, 224, 252, 58, 131, 251, 174, 45, 140, 107, 57, 118, 11, 18, 236, 203, 204, 19, 59, 27, 196, 3, 78, 21, 7, 22, 98, 197, 128, 17, 128, 90, 122, 54, 83, 253, 108, 156, 67, 59, 229, 236, 241, 69, 88, 99, 44, 127, 109, 204, 183, 150, 232, 187, 57, 228, 137, 209, 235, 241, 172},
+ expectedHash: []byte{40, 231, 187, 28, 3, 171, 168, 36, 177, 244, 118, 131, 218, 226, 106, 55, 245, 157, 244, 147, 144, 57, 41, 182, 65, 6, 13, 49, 38, 66, 237, 117, 124, 110, 250, 246, 248, 132, 201, 156, 195, 201, 142, 179, 122, 128, 195, 194, 187, 240, 129, 171, 168, 182, 101, 58, 194, 155, 99, 147, 49, 130, 161, 178},
},
{
data: []byte{'a'},
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
- expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
+ expectedHash: []byte{182, 25, 170, 240, 16, 153, 234, 4, 101, 238, 197, 154, 182, 168, 171, 96, 177, 33, 171, 117, 73, 78, 124, 239, 82, 255, 215, 121, 156, 95, 121, 171},
},
{
data: []byte{'a'},
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
- expectedHash: []byte{207, 233, 114, 94, 113, 212, 243, 160, 59, 232, 226, 77, 28, 81, 176, 61, 211, 213, 222, 190, 148, 196, 90, 166, 237, 56, 113, 148, 230, 154, 23, 105, 14, 97, 144, 211, 12, 122, 226, 207, 167, 203, 136, 193, 38, 249, 227, 187, 92, 238, 101, 97, 170, 255, 246, 209, 246, 98, 241, 150, 175, 253, 173, 206},
+ expectedHash: []byte{121, 28, 140, 244, 32, 222, 61, 255, 184, 65, 117, 84, 132, 197, 122, 214, 95, 249, 164, 77, 211, 192, 217, 59, 109, 255, 249, 253, 27, 142, 110, 29, 93, 153, 92, 211, 178, 198, 136, 34, 61, 157, 141, 94, 145, 191, 201, 134, 141, 138, 51, 26, 33, 187, 17, 196, 113, 234, 125, 219, 4, 41, 57, 120},
},
{
data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
- expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
+ expectedHash: []byte{17, 40, 99, 150, 206, 124, 196, 184, 41, 40, 50, 91, 113, 47, 8, 204, 2, 102, 202, 86, 157, 92, 218, 53, 151, 250, 234, 247, 191, 121, 113, 246},
},
{
data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
- expectedHash: []byte{110, 103, 29, 250, 27, 211, 235, 119, 112, 65, 49, 156, 6, 92, 66, 105, 133, 1, 187, 172, 169, 13, 186, 34, 105, 72, 252, 131, 12, 159, 91, 188, 79, 184, 240, 227, 40, 164, 72, 193, 65, 31, 227, 153, 191, 6, 117, 42, 82, 122, 33, 255, 92, 215, 215, 249, 2, 131, 170, 134, 39, 192, 222, 33},
+ expectedHash: []byte{100, 22, 249, 78, 47, 163, 220, 231, 228, 165, 226, 192, 221, 77, 106, 69, 115, 104, 208, 155, 124, 206, 225, 233, 98, 249, 232, 225, 114, 119, 110, 216, 117, 106, 85, 7, 200, 206, 139, 81, 116, 37, 215, 158, 89, 110, 74, 86, 66, 95, 117, 237, 70, 56, 62, 175, 48, 147, 162, 122, 253, 57, 123, 84},
},
}
@@ -275,6 +275,7 @@ func TestVerify(t *testing.T) {
// fail, otherwise Verify should still succeed.
modifyByte int64
modifyName bool
+ modifySize bool
modifyMode bool
modifyUID bool
modifyGID bool
@@ -323,6 +324,15 @@ func TestVerify(t *testing.T) {
modifyName: true,
shouldSucceed: false,
},
+ // Modified size should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifySize: true,
+ shouldSucceed: false,
+ },
// Modified mode should fail verification.
{
dataSize: usermem.PageSize,
@@ -482,6 +492,9 @@ func TestVerify(t *testing.T) {
if tc.modifyName {
verifyParams.Name = defaultName + "abc"
}
+ if tc.modifySize {
+ verifyParams.Size--
+ }
if tc.modifyMode {
verifyParams.Mode = defaultMode + 1
}
diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go
index faf191f39..9fbc5466f 100644
--- a/pkg/refsvfs2/refs_map.go
+++ b/pkg/refsvfs2/refs_map.go
@@ -62,7 +62,9 @@ func Register(obj CheckedObject) {
}
liveObjects[obj] = struct{}{}
liveObjectsMu.Unlock()
- logEvent(obj, "registered")
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, "registered")
+ }
}
}
@@ -75,31 +77,40 @@ func Unregister(obj CheckedObject) {
panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj))
}
delete(liveObjects, obj)
- logEvent(obj, "unregistered")
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, "unregistered")
+ }
}
}
// LogIncRef logs a reference increment.
func LogIncRef(obj CheckedObject, refs int64) {
- logEvent(obj, fmt.Sprintf("IncRef to %d", refs))
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("IncRef to %d", refs))
+ }
}
// LogTryIncRef logs a successful TryIncRef call.
func LogTryIncRef(obj CheckedObject, refs int64) {
- logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs))
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs))
+ }
}
// LogDecRef logs a reference decrement.
func LogDecRef(obj CheckedObject, refs int64) {
- logEvent(obj, fmt.Sprintf("DecRef to %d", refs))
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("DecRef to %d", refs))
+ }
}
// logEvent logs a message for the given reference-counted object.
+//
+// obj.LogRefs() should be checked before calling logEvent, in order to avoid
+// calling any text processing needed to evaluate msg.
func logEvent(obj CheckedObject, msg string) {
- if obj.LogRefs() {
- log.Infof("[%s %p] %s:", obj.RefType(), obj, msg)
- log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack()))
- }
+ log.Infof("[%s %p] %s:", obj.RefType(), obj, msg)
+ log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack()))
}
// DoLeakCheck iterates through the live object map and logs a message for each
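Note: the refs_map change hoists the cheap leakCheckEnabled() && obj.LogRefs() guard to the call sites so that the fmt.Sprintf message and the stack recording are never evaluated when logging is disabled. A minimal sketch of that guard pattern, with illustrative names only:

package main

import (
	"fmt"
	"log"
)

// logRefs stands in for leakCheckEnabled() && obj.LogRefs().
var logRefs = false

// logIncRef evaluates the boolean guard before doing any message formatting,
// so the common (logging disabled) path pays no fmt.Sprintf cost.
func logIncRef(obj string, refs int64) {
	if logRefs {
		log.Printf("[%s] %s", obj, fmt.Sprintf("IncRef to %d", refs))
	}
}

func main() {
	logIncRef("gofer.dentry", 2) // no formatting cost while logRefs is false
}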
diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go
index 8f50b4ee6..f64b6c6ae 100644
--- a/pkg/refsvfs2/refs_template.go
+++ b/pkg/refsvfs2/refs_template.go
@@ -13,10 +13,7 @@
// limitations under the License.
// Package refs_template defines a template that can be used by reference
-// counted objects. The "owner" template parameter is used in log messages to
-// indicate the type of reference-counted object that exhibited a reference
-// leak. As a result, structs that are embedded in other structs should not use
-// this template, since it will make tracking down leaks more difficult.
+// counted objects.
package refs_template
import (
@@ -43,9 +40,6 @@ var obj *T
// Refs implements refs.RefCounter. It keeps a reference count using atomic
// operations and calls the destructor when the count reaches zero.
//
-// Note that the number of references is actually refCount + 1 so that a default
-// zero-value Refs object contains one reference.
-//
// +stateify savable
type Refs struct {
// refCount is composed of two fields:
@@ -58,6 +52,13 @@ type Refs struct {
refCount int64
}
+// InitRefs initializes r with one reference and, if enabled, activates leak
+// checking.
+func (r *Refs) InitRefs() {
+ atomic.StoreInt64(&r.refCount, 1)
+ refsvfs2.Register(r)
+}
+
// RefType implements refsvfs2.CheckedObject.RefType.
func (r *Refs) RefType() string {
return fmt.Sprintf("%T", obj)[1:]
@@ -81,8 +82,7 @@ func (r *Refs) EnableLeakCheck() {
// ReadRefs returns the current number of references. The returned count is
// inherently racy and is unsafe to use without external synchronization.
func (r *Refs) ReadRefs() int64 {
- // Account for the internal -1 offset on refcounts.
- return atomic.LoadInt64(&r.refCount) + 1
+ return atomic.LoadInt64(&r.refCount)
}
// IncRef implements refs.RefCounter.IncRef.
@@ -90,8 +90,10 @@ func (r *Refs) ReadRefs() int64 {
//go:nosplit
func (r *Refs) IncRef() {
v := atomic.AddInt64(&r.refCount, 1)
- refsvfs2.LogIncRef(r, v+1)
- if v <= 0 {
+ if enableLogging {
+ refsvfs2.LogIncRef(r, v)
+ }
+ if v <= 1 {
panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType()))
}
}
@@ -105,7 +107,7 @@ func (r *Refs) IncRef() {
//go:nosplit
func (r *Refs) TryIncRef() bool {
const speculativeRef = 1 << 32
- if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) < 0 {
+ if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 {
// This object has already been freed.
atomic.AddInt64(&r.refCount, -speculativeRef)
return false
@@ -113,7 +115,9 @@ func (r *Refs) TryIncRef() bool {
// Turn into a real reference.
v := atomic.AddInt64(&r.refCount, -speculativeRef+1)
- refsvfs2.LogTryIncRef(r, v+1)
+ if enableLogging {
+ refsvfs2.LogTryIncRef(r, v)
+ }
return true
}
@@ -131,12 +135,14 @@ func (r *Refs) TryIncRef() bool {
//go:nosplit
func (r *Refs) DecRef(destroy func()) {
v := atomic.AddInt64(&r.refCount, -1)
- refsvfs2.LogDecRef(r, v+1)
+ if enableLogging {
+ refsvfs2.LogDecRef(r, v+1)
+ }
switch {
- case v < -1:
+ case v < 0:
panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType()))
- case v == -1:
+ case v == 0:
refsvfs2.Unregister(r)
// Call the destructor.
if destroy != nil {
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 3c66dc3c2..6b3627813 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
@@ -305,7 +304,7 @@ func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode,
}
// First create a pipe.
- p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize)
// Wrap the fileOps with our Fifo.
iops := &fifo{
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index e555672ad..52061175f 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -86,9 +86,9 @@ func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
}
// GetFile implements fs.InodeOperations.GetFile.
-func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+func (t *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
flags.Pread = true
- return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil
+ return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: t}), nil
}
// +stateify savable
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index fc0498f17..d6c65301c 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -431,9 +431,6 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
// Continue.
seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
-
- default:
- break
}
}
return done, nil
@@ -532,9 +529,6 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
// Write to that memory as usual.
seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
-
- default:
- break
}
}
return done, nil
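Note: the removed "default: break" arms above were no-ops, since Go switch cases never fall through and an empty default adds nothing. A tiny sketch, with illustrative names, showing equivalent behavior without the redundant arm:

package main

import "fmt"

// classify demonstrates that omitting an empty "default: break" arm leaves a
// Go switch's behavior unchanged: control simply continues after the switch.
func classify(n int) string {
	switch {
	case n < 0:
		return "negative"
	case n == 0:
		return "zero"
	}
	return "positive" // reached without any explicit default arm
}

func main() { fmt.Println(classify(3)) }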
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 998b697ca..cf4ed5de0 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -336,7 +336,7 @@ type Fifo struct {
// NewFifo creates a new named pipe.
func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
// First create a pipe.
- p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize)
// Build pipe InodeOperations.
iops := pipe.NewInodeOperations(ctx, perms, p)
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index 346cca558..d8c237753 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -110,7 +110,7 @@ func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.Vir
}
root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- root.EnableLeakCheck()
+ root.InitRefs()
var rootD kernfs.Dentry
rootD.InitRoot(&fs.Filesystem, root)
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 5986133e9..95c475a65 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -315,7 +315,7 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F
readPayload.MarshalUnsafe(outBuf[outHdrLen:])
outIOseq := usermem.BytesIOSequence(outBuf)
- n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
+ _, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
if err != nil {
t.Fatalf("Write failed :%v", err)
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 6de416da0..cd0eb56e5 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -219,16 +219,12 @@ func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOpt
}
fuseFD := device.Impl().(*DeviceFD)
-
fs := &filesystem{
devMinor: devMinor,
opts: opts,
conn: conn,
}
-
- fs.VFSFilesystem().IncRef()
fuseFD.fs = fs
-
return fs, nil
}
@@ -288,7 +284,7 @@ func (fs *filesystem) newRoot(ctx context.Context, creds *auth.Credentials, mode
i := &inode{fs: fs, nodeID: 1}
i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- i.EnableLeakCheck()
+ i.InitRefs()
var d kernfs.Dentry
d.InitRoot(&fs.Filesystem, i)
@@ -301,7 +297,7 @@ func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FU
i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
atomic.StoreUint64(&i.size, attr.Size)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- i.EnableLeakCheck()
+ i.InitRefs()
return i
}
diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go
index e1d9e3365..b2f4276b8 100644
--- a/pkg/sentry/fsimpl/fuse/utils_test.go
+++ b/pkg/sentry/fsimpl/fuse/utils_test.go
@@ -72,6 +72,7 @@ func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveReque
if err != nil {
return nil, nil, err
}
+ fs.VFSFilesystem().Init(vfsObj, nil, fs)
return fs.conn, &fuseDev.vfsfd, nil
}
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index ce1b2a390..3b5927702 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -98,7 +98,9 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
uid: uint32(opts.kuid),
gid: uint32(opts.kgid),
blockSize: usermem.PageSize, // arbitrary
- hostFD: -1,
+ readFD: -1,
+ writeFD: -1,
+ mmapFD: -1,
nlink: uint32(2),
}
refsvfs2.Register(child)
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 57a2ca43c..7ab298019 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -30,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Sync implements vfs.FilesystemImpl.Sync.
@@ -372,9 +371,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if len(name) > maxFilenameLen {
return syserror.ENAMETOOLONG
}
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
if parent.isDeleted() {
return syserror.ENOENT
}
@@ -389,6 +385,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if createInSyntheticDir == nil {
return syserror.EPERM
}
@@ -408,6 +407,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil && child.isSynthetic() {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
// The existence of a non-synthetic dentry at name would be inconclusive
// because the file it represents may have been deleted from the remote
// filesystem, so we would need to make an RPC to revalidate the dentry.
@@ -428,6 +430,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
// No cached dentry exists; however, there might still be an existing file
// at name. As above, we attempt the file creation RPC anyway.
if err := createInRemoteDir(parent, name, &ds); err != nil {
@@ -842,7 +847,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
mode: opts.Mode,
kuid: creds.EffectiveKUID,
kgid: creds.EffectiveKGID,
- pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize),
})
return nil
}
@@ -1162,18 +1167,21 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving
// Incorporate the fid that was opened by lcreate.
useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
if useRegularFileFD {
+ openFD := int32(-1)
+ if fdobj != nil {
+ openFD = int32(fdobj.Release())
+ }
child.handleMu.Lock()
if vfs.MayReadFileWithOpenFlags(opts.Flags) {
child.readFile = openFile
if fdobj != nil {
- child.hostFD = int32(fdobj.Release())
+ child.readFD = openFD
+ child.mmapFD = openFD
}
- } else if fdobj != nil {
- // Can't use fdobj if it's not readable.
- fdobj.Close()
}
if vfs.MayWriteFileWithOpenFlags(opts.Flags) {
child.writeFile = openFile
+ child.writeFD = openFD
}
child.handleMu.Unlock()
}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 6f82ce61b..53bcc9986 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -548,11 +548,16 @@ func (fs *filesystem) Release(ctx context.Context) {
d.cache.DropAll(mf)
d.dirty.RemoveAll()
d.dataMu.Unlock()
- // Close the host fd if one exists.
- if d.hostFD >= 0 {
- syscall.Close(int(d.hostFD))
- d.hostFD = -1
+ // Close host FDs if they exist.
+ if d.readFD >= 0 {
+ syscall.Close(int(d.readFD))
}
+ if d.writeFD >= 0 && d.readFD != d.writeFD {
+ syscall.Close(int(d.writeFD))
+ }
+ d.readFD = -1
+ d.writeFD = -1
+ d.mmapFD = -1
d.handleMu.Unlock()
}
// There can't be any specialFileFDs still using fs, since each such
@@ -726,15 +731,17 @@ type dentry struct {
// - If this dentry represents a regular file or directory, readFile is the
// p9.File used for reads by all regularFileFDs/directoryFDs representing
- // this dentry.
+ // this dentry, and readFD (if not -1) is a host FD equivalent to readFile
+ // used as a faster alternative.
//
// - If this dentry represents a regular file, writeFile is the p9.File
- // used for writes by all regularFileFDs representing this dentry.
+ // used for writes by all regularFileFDs representing this dentry, and
+ // writeFD (if not -1) is a host FD equivalent to writeFile used as a
+ // faster alternative.
//
- // - If this dentry represents a regular file, hostFD is the host FD used
- // for memory mappings and I/O (when applicable) in preference to readFile
- // and writeFile. hostFD is always readable; if !writeFile.isNil(), it must
- // also be writable. If hostFD is -1, no such host FD is available.
+ // - If this dentry represents a regular file, mmapFD is the host FD used
+ // for memory mappings. If mmapFD is -1, no such FD is available, and the
+ // internal page cache implementation is used for memory mappings instead.
//
// These fields are protected by handleMu.
//
@@ -742,10 +749,17 @@ type dentry struct {
// either p9.File transitions from closed (isNil() == true) to open
// (isNil() == false), it may be mutated with handleMu locked, but cannot
// be closed until the dentry is destroyed.
+ //
+ // readFD and writeFD may or may not be the same file descriptor. mmapFD is
+ // always either -1 or equal to readFD; if !writeFile.isNil() (the file has
+ // been opened for writing), it is additionally either -1 or equal to
+ // writeFD.
handleMu sync.RWMutex `state:"nosave"`
readFile p9file `state:"nosave"`
writeFile p9file `state:"nosave"`
- hostFD int32 `state:"nosave"`
+ readFD int32 `state:"nosave"`
+ writeFD int32 `state:"nosave"`
+ mmapFD int32 `state:"nosave"`
dataMu sync.RWMutex `state:"nosave"`
@@ -829,7 +843,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
uid: uint32(fs.opts.dfltuid),
gid: uint32(fs.opts.dfltgid),
blockSize: usermem.PageSize,
- hostFD: -1,
+ readFD: -1,
+ writeFD: -1,
+ mmapFD: -1,
}
d.pf.dentry = d
if mask.UID {
@@ -1220,7 +1236,9 @@ func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
// d.checkCachingLocked().
r := atomic.AddInt64(&d.refs, 1)
- refsvfs2.LogIncRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogIncRef(d, r)
+ }
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
@@ -1231,7 +1249,9 @@ func (d *dentry) TryIncRef() bool {
return false
}
if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
- refsvfs2.LogTryIncRef(d, r+1)
+ if d.LogRefs() {
+ refsvfs2.LogTryIncRef(d, r+1)
+ }
return true
}
}
@@ -1251,7 +1271,9 @@ func (d *dentry) DecRef(ctx context.Context) {
// responsible for ensuring that d.checkCachingLocked will be called later.
func (d *dentry) decRefNoCaching() int64 {
r := atomic.AddInt64(&d.refs, -1)
- refsvfs2.LogDecRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogDecRef(d, r)
+ }
if r < 0 {
panic("gofer.dentry.decRefNoCaching() called without holding a reference")
}
@@ -1469,10 +1491,15 @@ func (d *dentry) destroyLocked(ctx context.Context) {
}
d.readFile = p9file{}
d.writeFile = p9file{}
- if d.hostFD >= 0 {
- syscall.Close(int(d.hostFD))
- d.hostFD = -1
+ if d.readFD >= 0 {
+ syscall.Close(int(d.readFD))
+ }
+ if d.writeFD >= 0 && d.readFD != d.writeFD {
+ syscall.Close(int(d.writeFD))
}
+ d.readFD = -1
+ d.writeFD = -1
+ d.mmapFD = -1
d.handleMu.Unlock()
if !d.file.isNil() {
@@ -1584,7 +1611,8 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
d.handleMu.RUnlock()
}
- fdToClose := int32(-1)
+ var fdsToCloseArr [2]int32
+ fdsToClose := fdsToCloseArr[:0]
invalidateTranslations := false
d.handleMu.Lock()
if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc {
@@ -1615,56 +1643,88 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
return err
}
- if d.hostFD < 0 && h.fd >= 0 && openReadable && (d.writeFile.isNil() || openWritable) {
- // We have no existing FD, and the new FD meets the requirements
- // for d.hostFD, so start using it.
- d.hostFD = h.fd
- } else if d.hostFD >= 0 && d.writeFile.isNil() && openWritable {
- // We have an existing read-only FD, but the file has just been
- // opened for writing, so we need to start supporting writable memory
- // mappings. This may race with callers of d.pf.FD() using the existing
- // FD, so in most cases we need to delay closing the old FD until after
- // invalidating memmap.Translations that might have observed it.
- if !openReadable || h.fd < 0 {
- // We don't have a read/write FD, so we have no FD that can be
- // used to create writable memory mappings. Switch to using the
- // internal page cache.
- invalidateTranslations = true
- fdToClose = d.hostFD
- d.hostFD = -1
- } else if d.fs.opts.overlayfsStaleRead {
- // We do have a read/write FD, but it may not be coherent with
- // the existing read-only FD, so we must switch to mappings of
- // the new FD in both the application and sentry.
- if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil {
- d.handleMu.Unlock()
- ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err)
- h.close(ctx)
- return err
+ // Update d.readFD and d.writeFD.
+ if h.fd >= 0 {
+ if openReadable && openWritable && (d.readFD < 0 || d.writeFD < 0 || d.readFD != d.writeFD) {
+ // Replace existing FDs with this one.
+ if d.readFD >= 0 {
+ // We already have a readable FD that may be in use by
+ // concurrent callers of d.pf.FD().
+ if d.fs.opts.overlayfsStaleRead {
+ // If overlayfsStaleRead is in effect, then the new FD
+ // may not be coherent with the existing one, so we
+ // have no choice but to switch to mappings of the new
+ // FD in both the application and sentry.
+ if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil {
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err)
+ h.close(ctx)
+ return err
+ }
+ fdsToClose = append(fdsToClose, d.readFD)
+ invalidateTranslations = true
+ d.readFD = h.fd
+ } else {
+ // Otherwise, we want to avoid invalidating existing
+ // memmap.Translations (which is expensive); instead, use
+ // dup3 to make the old file descriptor refer to the new
+ // file description, then close the new file descriptor
+ // (which is no longer needed). Racing callers of d.pf.FD()
+ // may use the old or new file description, but this
+ // doesn't matter since they refer to the same file, and
+ // any racing mappings must be read-only.
+ if err := syscall.Dup3(int(h.fd), int(d.readFD), syscall.O_CLOEXEC); err != nil {
+ oldFD := d.readFD
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldFD, err)
+ h.close(ctx)
+ return err
+ }
+ fdsToClose = append(fdsToClose, h.fd)
+ h.fd = d.readFD
+ }
+ } else {
+ d.readFD = h.fd
}
- invalidateTranslations = true
- fdToClose = d.hostFD
- d.hostFD = h.fd
- } else {
- // We do have a read/write FD. To avoid invalidating existing
- // memmap.Translations (which is expensive), use dup3 to make
- // the old file descriptor refer to the new file description,
- // then close the new file descriptor (which is no longer
- // needed). Racing callers of d.pf.FD() may use the old or new
- // file description, but this doesn't matter since they refer
- // to the same file, and any racing mappings must be read-only.
- if err := syscall.Dup3(int(h.fd), int(d.hostFD), syscall.O_CLOEXEC); err != nil {
- oldHostFD := d.hostFD
- d.handleMu.Unlock()
- ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldHostFD, err)
- h.close(ctx)
- return err
+ if d.writeFD != h.fd && d.writeFD >= 0 {
+ fdsToClose = append(fdsToClose, d.writeFD)
}
- fdToClose = h.fd
+ d.writeFD = h.fd
+ d.mmapFD = h.fd
+ } else if openReadable && d.readFD < 0 {
+ d.readFD = h.fd
+ // If the file has not been opened for writing, the new FD may
+ // be used for read-only memory mappings. If the file was
+ // previously opened for reading (without an FD), then existing
+ // translations of the file may use the internal page cache;
+ // invalidate those mappings.
+ if d.writeFile.isNil() {
+ invalidateTranslations = !d.readFile.isNil()
+ d.mmapFD = h.fd
+ }
+ } else if openWritable && d.writeFD < 0 {
+ d.writeFD = h.fd
+ if d.readFD >= 0 {
+ // We have an existing read-only FD, but the file has just
+ // been opened for writing, so we need to start supporting
+ // writable memory mappings. However, the new FD is not
+ // readable, so we have no FD that can be used to create
+ // writable memory mappings. Switch to using the internal
+ // page cache.
+ invalidateTranslations = true
+ d.mmapFD = -1
+ }
+ } else {
+ // The new FD is not useful.
+ fdsToClose = append(fdsToClose, h.fd)
}
- } else {
- // h.fd is not useful.
- fdToClose = h.fd
+ } else if openWritable && d.writeFD < 0 && d.mmapFD >= 0 {
+ // We have an existing read-only FD, but the file has just been
+ // opened for writing, so we need to start supporting writable
+ // memory mappings. However, we have no writable host FD. Switch to
+ // using the internal page cache.
+ invalidateTranslations = true
+ d.mmapFD = -1
}
// Switch to new fids.
@@ -1698,8 +1758,8 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
d.mappings.InvalidateAll(memmap.InvalidateOpts{})
d.mapsMu.Unlock()
}
- if fdToClose >= 0 {
- syscall.Close(int(fdToClose))
+ for _, fd := range fdsToClose {
+ syscall.Close(int(fd))
}
return nil
@@ -1709,7 +1769,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool
func (d *dentry) readHandleLocked() handle {
return handle{
file: d.readFile,
- fd: d.hostFD,
+ fd: d.readFD,
}
}
@@ -1717,7 +1777,7 @@ func (d *dentry) readHandleLocked() handle {
func (d *dentry) writeHandleLocked() handle {
return handle{
file: d.writeFile,
- fd: d.hostFD,
+ fd: d.writeFD,
}
}
@@ -1730,16 +1790,24 @@ func (d *dentry) syncRemoteFile(ctx context.Context) error {
// Preconditions: d.handleMu must be locked.
func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
// If we have a host FD, fsyncing it is likely to be faster than an fsync
- // RPC.
- if d.hostFD >= 0 {
+ // RPC. Prefer syncing write handles over read handles, since some remote
+ // filesystem implementations may not sync changes made through write
+ // handles otherwise.
+ if d.writeFD >= 0 {
ctx.UninterruptibleSleepStart(false)
- err := syscall.Fsync(int(d.hostFD))
+ err := syscall.Fsync(int(d.writeFD))
ctx.UninterruptibleSleepFinish(false)
return err
}
if !d.writeFile.isNil() {
return d.writeFile.fsync(ctx)
}
+ if d.readFD >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ err := syscall.Fsync(int(d.readFD))
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+ }
if !d.readFile.isNil() {
return d.readFile.fsync(ctx)
}
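Note: the dup3 path above keeps an already-published descriptor number valid by repointing it at the new open file description instead of handing out a new number, so racing callers of d.pf.FD() and existing read-only mappings need not be invalidated. A standalone, Linux-only sketch of that technique follows; the file paths and helper name are illustrative, not part of the change.

//go:build linux

package main

import (
	"fmt"
	"os"
	"syscall"
)

// replaceFD repoints oldFD (a descriptor number that concurrent users already
// hold) at newFD's open file description, then closes newFD, which is no
// longer needed. Users of oldFD keep the same number throughout.
func replaceFD(oldFD, newFD int) error {
	if err := syscall.Dup3(newFD, oldFD, syscall.O_CLOEXEC); err != nil {
		return fmt.Errorf("dup3(%d, %d): %w", newFD, oldFD, err)
	}
	return syscall.Close(newFD)
}

func main() {
	f1, err := os.Open("/dev/null")
	if err != nil {
		panic(err)
	}
	f2, err := os.Open("/dev/zero")
	if err != nil {
		panic(err)
	}
	if err := replaceFD(int(f1.Fd()), int(f2.Fd())); err != nil {
		panic(err)
	}
	// f1's descriptor number now refers to /dev/zero.
	buf := make([]byte, 1)
	if _, err := syscall.Read(int(f1.Fd()), buf); err != nil {
		panic(err)
	}
	fmt.Println(buf[0]) // prints 0
}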
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index dc8a890cb..652142ecc 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -326,7 +326,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error)
// dentry.readHandleLocked() without locking dentry.dataMu.
rw.d.handleMu.RLock()
h := rw.d.readHandleLocked()
- if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
+ if (rw.d.mmapFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
n, err := h.readToBlocksAt(rw.ctx, dsts, rw.off)
rw.d.handleMu.RUnlock()
rw.off += n
@@ -446,7 +446,7 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro
// without locking dentry.dataMu.
rw.d.handleMu.RLock()
h := rw.d.writeHandleLocked()
- if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
+ if (rw.d.mmapFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
n, err := h.writeFromBlocksAt(rw.ctx, srcs, rw.off)
rw.off += n
rw.d.dataMu.Lock()
@@ -648,7 +648,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
return syserror.ENODEV
}
d.handleMu.RLock()
- haveFD := d.hostFD >= 0
+ haveFD := d.mmapFD >= 0
d.handleMu.RUnlock()
if !haveFD {
return syserror.ENODEV
@@ -669,7 +669,7 @@ func (d *dentry) mayCachePages() bool {
return true
}
d.handleMu.RLock()
- haveFD := d.hostFD >= 0
+ haveFD := d.mmapFD >= 0
d.handleMu.RUnlock()
return haveFD
}
@@ -727,7 +727,7 @@ func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR,
// Translate implements memmap.Mappable.Translate.
func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
d.handleMu.RLock()
- if d.hostFD >= 0 && !d.fs.opts.forcePageCache {
+ if d.mmapFD >= 0 && !d.fs.opts.forcePageCache {
d.handleMu.RUnlock()
mr := optional
if d.fs.opts.limitHostFDTranslation {
@@ -881,7 +881,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef.
//
// dentryPlatformFile is only used when a host FD representing the remote file
-// is available (i.e. dentry.hostFD >= 0), and that FD is used for application
+// is available (i.e. dentry.mmapFD >= 0), and that FD is used for application
// memory mappings (i.e. !filesystem.opts.forcePageCache).
//
// +stateify savable
@@ -892,8 +892,8 @@ type dentryPlatformFile struct {
// by dentry.dataMu.
fdRefs fsutil.FrameRefSet
- // If this dentry represents a regular file, and dentry.hostFD >= 0,
- // hostFileMapper caches mappings of dentry.hostFD.
+ // If this dentry represents a regular file, and dentry.mmapFD >= 0,
+ // hostFileMapper caches mappings of dentry.mmapFD.
hostFileMapper fsutil.HostFileMapper
// hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
@@ -918,12 +918,12 @@ func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) {
func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
- return d.hostFileMapper.MapInternal(fr, int(d.hostFD), at.Write)
+ return d.hostFileMapper.MapInternal(fr, int(d.mmapFD), at.Write)
}
// FD implements memmap.File.FD.
func (d *dentryPlatformFile) FD() int {
d.handleMu.RLock()
defer d.handleMu.RUnlock()
- return int(d.hostFD)
+ return int(d.mmapFD)
}
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
index 17849dcc0..c90071e4e 100644
--- a/pkg/sentry/fsimpl/gofer/save_restore.go
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -139,7 +139,9 @@ func (d *dentry) beforeSave() {
// afterLoad is invoked by stateify.
func (d *dentry) afterLoad() {
- d.hostFD = -1
+ d.readFD = -1
+ d.writeFD = -1
+ d.mmapFD = -1
if atomic.LoadInt64(&d.refs) != -1 {
refsvfs2.Register(d)
}
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 39b902a3e..435a21d77 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -126,8 +126,8 @@ func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fil
isTTY: isTTY,
savable: savable,
}
+ i.InitRefs()
i.CachedMappable.Init(hostFD)
- i.EnableLeakCheck()
// If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and
// handle blocking behavior in the sentry.
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 8a447e29f..60acc367f 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -84,6 +84,8 @@ type ConnectedEndpoint struct {
// init performs initialization required for creating new ConnectedEndpoints and
// for restoring them.
func (c *ConnectedEndpoint) init() *syserr.Error {
+ c.InitRefs()
+
family, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN)
if err != nil {
return syserr.FromError(err)
@@ -132,7 +134,6 @@ func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable
// ConnectedEndpointRefs start off with a single reference. We need two.
e.IncRef()
- e.EnableLeakCheck()
return &e, nil
}
@@ -376,8 +377,7 @@ func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr s
return nil, err
}
- // ConnectedEndpointRefs start off with a single reference. We need two.
+ // e starts off with a single reference. We need two.
e.IncRef()
- e.EnableLeakCheck()
return &e, nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index f81056023..e77523f22 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -207,24 +207,23 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
// Preconditions:
// * Filesystem.mu must be locked for at least reading.
// * isDir(parentInode) == true.
-func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) {
- if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
- return "", err
+func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error {
+ if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
}
- pc := rp.Component()
- if pc == "." || pc == ".." {
- return "", syserror.EEXIST
+ if name == "." || name == ".." {
+ return syserror.EEXIST
}
- if len(pc) > linux.NAME_MAX {
- return "", syserror.ENAMETOOLONG
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
}
- if _, ok := parent.children[pc]; ok {
- return "", syserror.EEXIST
+ if _, ok := parent.children[name]; ok {
+ return syserror.EEXIST
}
if parent.VFSDentry().IsDead() {
- return "", syserror.ENOENT
+ return syserror.ENOENT
}
- return pc, nil
+ return nil
}
// checkDeleteLocked checks that the file represented by vfsd may be deleted.
@@ -265,7 +264,7 @@ func (fs *Filesystem) Release(ctx context.Context) {
//
// Precondition: Filesystem.mu is held.
func (d *Dentry) releaseKeptDentriesLocked(ctx context.Context) {
- if d.inode.Keep() {
+ if d.inode.Keep() && d != d.fs.root {
d.decRefLocked(ctx)
}
@@ -352,10 +351,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if rp.Mount() != vd.Mount() {
return syserror.EXDEV
}
@@ -394,8 +396,8 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -430,10 +432,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if err := rp.Mount().CheckBeginWrite(); err != nil {
return err
}
@@ -657,8 +662,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Can we create the dst dentry?
var dst *Dentry
- pc, err := checkCreateLocked(ctx, rp, dstDir)
- switch err {
+ pc := rp.Component()
+ switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err {
case nil:
// Ok, continue with rename as replacement.
case syserror.EEXIST:
@@ -822,10 +827,13 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if err := rp.Mount().CheckBeginWrite(); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index d9d76758a..eac578f25 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -568,13 +568,6 @@ func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) e
return o.Unlink(ctx, name, child)
}
-// +stateify savable
-type renameAcrossDifferentImplementationsError struct{}
-
-func (renameAcrossDifferentImplementationsError) Error() string {
- return "rename across inodes with different implementations"
-}
-
// Rename implements Inode.Rename.
//
// Precondition: Rename may only be called across two directory inodes with
@@ -585,13 +578,18 @@ func (renameAcrossDifferentImplementationsError) Error() string {
//
// Postcondition: reference on any replaced dentry transferred to caller.
func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error {
+ if !o.writable {
+ return syserror.EPERM
+ }
+
dst, ok := dstDir.(interface{}).(*OrderedChildren)
if !ok {
- return renameAcrossDifferentImplementationsError{}
+ return syserror.EXDEV
}
- if !o.writable || !dst.writable {
+ if !dst.writable {
return syserror.EPERM
}
+
// Note: There's a potential deadlock below if concurrent calls to Rename
// refer to the same src and dst directories in reverse. We avoid any
// ordering issues because the caller is required to serialize concurrent
@@ -662,7 +660,7 @@ var _ Inode = (*StaticDirectory)(nil)
func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
inode := &StaticDirectory{}
inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts)
- inode.EnableLeakCheck()
+ inode.InitRefs()
inode.OrderedChildren.Init(OrderedChildrenOptions{})
links := inode.OrderedChildren.Populate(children)
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index abb477c7d..c14abcff4 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -222,7 +222,9 @@ func (d *Dentry) IncRef() {
// d.refs may be 0 if d.fs.mu is locked, which serializes against
// d.cacheLocked().
r := atomic.AddInt64(&d.refs, 1)
- refsvfs2.LogIncRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogIncRef(d, r)
+ }
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
@@ -233,7 +235,9 @@ func (d *Dentry) TryIncRef() bool {
return false
}
if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
- refsvfs2.LogTryIncRef(d, r+1)
+ if d.LogRefs() {
+ refsvfs2.LogTryIncRef(d, r+1)
+ }
return true
}
}
@@ -242,7 +246,9 @@ func (d *Dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *Dentry) DecRef(ctx context.Context) {
r := atomic.AddInt64(&d.refs, -1)
- refsvfs2.LogDecRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogDecRef(d, r)
+ }
if r == 0 {
d.fs.mu.Lock()
d.cacheLocked(ctx)
@@ -254,7 +260,9 @@ func (d *Dentry) DecRef(ctx context.Context) {
func (d *Dentry) decRefLocked(ctx context.Context) {
r := atomic.AddInt64(&d.refs, -1)
- refsvfs2.LogDecRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogDecRef(d, r)
+ }
if r == 0 {
d.cacheLocked(ctx)
} else if r < 0 {
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 2418eec44..e63588e33 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -109,7 +109,7 @@ func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credential
dir := &readonlyDir{}
dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- dir.EnableLeakCheck()
+ dir.InitRefs()
dir.IncLinks(dir.OrderedChildren.Populate(contents))
return dir
}
@@ -147,7 +147,7 @@ func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode
dir.fs = fs
dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
- dir.EnableLeakCheck()
+ dir.InitRefs()
dir.IncLinks(dir.OrderedChildren.Populate(contents))
return dir
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 4506642ca..469f3a33d 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -409,7 +409,7 @@ func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) er
if dirent.Name == "." || dirent.Name == ".." {
continue
}
- child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds)
+ child, _, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds)
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 04ca85f1a..bc07d72c0 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -121,63 +122,63 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
// * !rp.Done().
-func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, lookupLayer, error) {
if !d.isDir() {
- return nil, syserror.ENOTDIR
+ return nil, lookupLayerNone, syserror.ENOTDIR
}
if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
afterSymlink:
name := rp.Component()
if name == "." {
rp.Advance()
- return d, nil
+ return d, d.topLookupLayer(), nil
}
if name == ".." {
if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
} else if isRoot || d.parent == nil {
rp.Advance()
- return d, nil
+ return d, d.topLookupLayer(), nil
}
if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
rp.Advance()
- return d.parent, nil
+ return d.parent, d.parent.topLookupLayer(), nil
}
- child, err := fs.getChildLocked(ctx, d, name, ds)
+ child, topLookupLayer, err := fs.getChildLocked(ctx, d, name, ds)
if err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
target, err := child.readlink(ctx)
if err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
if err := rp.HandleSymlink(target); err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
goto afterSymlink // don't check the current directory again
}
rp.Advance()
- return child, nil
+ return child, topLookupLayer, nil
}
// Preconditions:
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
-func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, lookupLayer, error) {
if child, ok := parent.children[name]; ok {
- return child, nil
+ return child, child.topLookupLayer(), nil
}
- child, err := fs.lookupLocked(ctx, parent, name)
+ child, topLookupLayer, err := fs.lookupLocked(ctx, parent, name)
if err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
if parent.children == nil {
parent.children = make(map[string]*dentry)
@@ -185,16 +186,16 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
parent.children[name] = child
// child's refcount is initially 0, so it may be dropped after traversal.
*ds = appendDentry(*ds, child)
- return child, nil
+ return child, topLookupLayer, nil
}
// Preconditions:
// * fs.renameMu must be locked.
// * parent.dirMu must be locked.
-func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
+func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, lookupLayer, error) {
childPath := fspath.Parse(name)
child := fs.newDentry()
- existsOnAnyLayer := false
+ topLookupLayer := lookupLayerNone
var lookupErr error
vfsObj := fs.vfsfs.VirtualFilesystem()
@@ -215,7 +216,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
defer childVD.DecRef(ctx)
mask := uint32(linux.STATX_TYPE)
- if !existsOnAnyLayer {
+ if topLookupLayer == lookupLayerNone {
// Mode, UID, GID, and (for non-directories) inode number come from
// the topmost layer on which the file exists.
mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
@@ -238,10 +239,13 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
if isWhiteout(&stat) {
// This is a whiteout, so it "doesn't exist" on this layer, and
// layers below this one are ignored.
+ if isUpper {
+ topLookupLayer = lookupLayerUpperWhiteout
+ }
return false
}
isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR
- if existsOnAnyLayer && !isDir {
+ if topLookupLayer != lookupLayerNone && !isDir {
// Directories are not merged with non-directory files from lower
// layers; instead, layers including and below the first
// non-directory file are ignored. (This file must be a directory
@@ -258,8 +262,12 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
} else {
child.lowerVDs = append(child.lowerVDs, childVD)
}
- if !existsOnAnyLayer {
- existsOnAnyLayer = true
+ if topLookupLayer == lookupLayerNone {
+ if isUpper {
+ topLookupLayer = lookupLayerUpper
+ } else {
+ topLookupLayer = lookupLayerLower
+ }
child.mode = uint32(stat.Mode)
child.uid = stat.UID
child.gid = stat.GID
@@ -288,11 +296,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
if lookupErr != nil {
child.destroyLocked(ctx)
- return nil, lookupErr
+ return nil, topLookupLayer, lookupErr
}
- if !existsOnAnyLayer {
+ if !topLookupLayer.existsInOverlay() {
child.destroyLocked(ctx)
- return nil, syserror.ENOENT
+ return nil, topLookupLayer, syserror.ENOENT
}
// Device and inode numbers were copied from the topmost layer above;
@@ -306,7 +314,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
if err != nil {
ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err)
child.destroyLocked(ctx)
- return nil, err
+ return nil, topLookupLayer, err
}
child.devMajor = linux.UNNAMED_MAJOR
child.devMinor = childDevMinor
@@ -315,7 +323,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
parent.IncRef()
child.parent = parent
child.name = name
- return child, nil
+ return child, topLookupLayer, nil
}
// lookupLayerLocked is similar to lookupLocked, but only returns information
@@ -414,7 +422,7 @@ func (ll lookupLayer) existsInOverlay() bool {
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
for !rp.Final() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
@@ -434,7 +442,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
d := rp.Start().Impl().(*dentry)
for !rp.Done() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
@@ -469,9 +477,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if name == "." || name == ".." {
return syserror.EEXIST
}
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
if parent.vfsd.IsDead() {
return syserror.ENOENT
}
@@ -495,6 +500,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+
// Ensure that the parent directory is copied-up so that we can create the
// new file in the upper layer.
if err := parent.copyUpLocked(ctx); err != nil {
@@ -797,9 +806,9 @@ afterTrailingSymlink:
}
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
- child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
if err == syserror.ENOENT && mayCreate {
- fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds)
+ fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout)
parent.dirMu.Unlock()
return fd, err
}
@@ -899,7 +908,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts *
// Preconditions:
// * parent.dirMu must be locked.
// * parent does not already contain a child named rp.Component().
-func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
+func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry, haveUpperWhiteout bool) (*vfs.FileDescription, error) {
creds := rp.Credentials()
if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil {
return nil, err
@@ -924,19 +933,12 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
Start: parent.upperVD,
Path: fspath.Parse(childName),
}
- // We don't know if a whiteout exists on the upper layer; speculatively
- // unlink it.
- //
- // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do
- // know whether a whiteout exists.
- var haveUpperWhiteout bool
- switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err {
- case nil:
- haveUpperWhiteout = true
- case syserror.ENOENT:
- haveUpperWhiteout = false
- default:
- return nil, err
+ // Unlink the whiteout if it exists.
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ log.Warningf("overlay.filesystem.createAndOpenLocked: failed to unlink whiteout: %v", err)
+ return nil, err
+ }
}
// Create the file on the upper layer, and get an FD representing it.
upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{
@@ -967,7 +969,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
}
// Re-lookup to get a dentry representing the new file, which is needed for
// the returned FD.
- child, err := fs.getChildLocked(ctx, parent, childName, ds)
+ child, _, err := fs.getChildLocked(ctx, parent, childName, ds)
if err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr))
@@ -1047,7 +1049,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// directory, we need to check for write permission on it.
oldParent.dirMu.Lock()
defer oldParent.dirMu.Unlock()
- renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
+ renamed, _, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
if err != nil {
return err
}
@@ -1079,20 +1081,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if newParent.vfsd.IsDead() {
return syserror.ENOENT
}
- replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName)
- if err != nil {
- return err
- }
var (
- replaced *dentry
- replacedVFSD *vfs.Dentry
- whiteouts map[string]bool
+ replaced *dentry
+ replacedVFSD *vfs.Dentry
+ replacedLayer lookupLayer
+ whiteouts map[string]bool
)
- if replacedLayer.existsInOverlay() {
- replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds)
- if err != nil {
- return err
- }
+ replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds)
+ if err != nil && err != syserror.ENOENT {
+ return err
+ }
+ if replaced != nil {
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
@@ -1296,7 +1295,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// Unlike UnlinkAt, we need a dentry representing the child directory being
// removed in order to verify that it's empty.
- child, err := fs.getChildLocked(ctx, parent, name, &ds)
+ child, _, err := fs.getChildLocked(ctx, parent, name, &ds)
if err != nil {
return err
}
@@ -1548,7 +1547,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if parentMode&linux.S_ISVTX != 0 {
// If the parent's sticky bit is set, we need a child dentry to get
// its owner.
- child, err = fs.getChildLocked(ctx, parent, name, &ds)
+ child, _, err = fs.getChildLocked(ctx, parent, name, &ds)
if err != nil {
return err
}
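
One detail worth calling out from the filesystem.go hunks above: doCreateAt now applies the rp.MustBeDir() check only after the existence check, so creating over an existing name with a trailing slash reports EEXIST rather than ENOENT. A tiny sketch of that ordering, using invented names (this is not the vfs API, just the control flow):

package main

import (
	"errors"
	"fmt"
)

var (
	errEEXIST = errors.New("EEXIST")
	errENOENT = errors.New("ENOENT")
)

// createAt mirrors the reordered checks in doCreateAt: existence first,
// then the "trailing slash but not creating a directory" case.
func createAt(existing map[string]bool, name string, creatingDir, mustBeDir bool) error {
	if existing[name] {
		return errEEXIST // reported even if the path had a trailing slash
	}
	if !creatingDir && mustBeDir {
		return errENOENT // "name/" cannot refer to a non-directory being created
	}
	existing[name] = true
	return nil
}

func main() {
	files := map[string]bool{"foo": true}
	fmt.Println(createAt(files, "foo", false, true))  // EEXIST, not ENOENT
	fmt.Println(createAt(files, "bar", false, true))  // ENOENT
	fmt.Println(createAt(files, "baz", false, false)) // <nil>
}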
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index f6c58f2e7..3492409b2 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -514,7 +514,9 @@ func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
// d.checkDropLocked().
r := atomic.AddInt64(&d.refs, 1)
- refsvfs2.LogIncRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogIncRef(d, r)
+ }
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
@@ -525,7 +527,9 @@ func (d *dentry) TryIncRef() bool {
return false
}
if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
- refsvfs2.LogTryIncRef(d, r+1)
+ if d.LogRefs() {
+ refsvfs2.LogTryIncRef(d, r+1)
+ }
return true
}
}
@@ -534,7 +538,9 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
r := atomic.AddInt64(&d.refs, -1)
- refsvfs2.LogDecRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogDecRef(d, r)
+ }
if r == 0 {
d.fs.renameMu.Lock()
d.checkDropLocked(ctx)
@@ -546,7 +552,9 @@ func (d *dentry) DecRef(ctx context.Context) {
func (d *dentry) decRefLocked(ctx context.Context) {
r := atomic.AddInt64(&d.refs, -1)
- refsvfs2.LogDecRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogDecRef(d, r)
+ }
if r == 0 {
d.checkDropLocked(ctx)
} else if r < 0 {
@@ -696,6 +704,13 @@ func (d *dentry) topLayer() vfs.VirtualDentry {
return vd
}
+func (d *dentry) topLookupLayer() lookupLayer {
+ if d.upperVD.Ok() {
+ return lookupLayerUpper
+ }
+ return lookupLayerLower
+}
+
func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
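
The topLookupLayer helper above, together with the lookupLayerUpperWhiteout value threaded through stepLocked and getChildLocked in filesystem.go, lets OpenAt tell createAndOpenLocked whether a whiteout actually exists, replacing the earlier speculative UnlinkAt. An illustrative sketch of the layer enum and the existsInOverlay predicate follows; the real declarations live elsewhere in the overlay package, so treat this as an approximation rather than the exact definitions.

package main

import "fmt"

type lookupLayer int

const (
	lookupLayerNone          lookupLayer = iota // name not found on any layer
	lookupLayerLower                            // topmost existence is on a lower layer
	lookupLayerUpper                            // exists on the upper (writable) layer
	lookupLayerUpperWhiteout                    // upper layer has a whiteout for the name
)

// existsInOverlay reports whether the name is visible to overlay users.
// A whiteout means "deleted", so it does not count as existing.
func (ll lookupLayer) existsInOverlay() bool {
	return ll == lookupLayerLower || ll == lookupLayerUpper
}

func main() {
	for _, ll := range []lookupLayer{lookupLayerNone, lookupLayerLower, lookupLayerUpper, lookupLayerUpperWhiteout} {
		fmt.Println(int(ll), ll.existsInOverlay())
	}
}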
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index e44b79b68..0ecb592cf 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -101,7 +101,7 @@ type inode struct {
func newInode(ctx context.Context, fs *filesystem) *inode {
creds := auth.CredentialsFromContext(ctx)
return &inode{
- pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize),
ino: fs.Filesystem.NextIno(),
uid: creds.EffectiveKUID,
gid: creds.EffectiveKGID,
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index cb3c5e0fd..e001d5032 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -60,7 +60,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace,
// Note: credentials are overridden by taskOwnedInode.
subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- subInode.EnableLeakCheck()
+ subInode.InitRefs()
inode := &taskOwnedInode{Inode: subInode, owner: task}
return inode
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 19011b010..dc46a09bc 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -91,7 +91,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
taskInode := &taskInode{task: task}
// Note: credentials are overridden by taskOwnedInode.
taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
- taskInode.EnableLeakCheck()
+ taskInode.InitRefs()
inode := &taskOwnedInode{Inode: taskInode, owner: task}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index d268b44be..3ec4471f5 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -128,7 +128,7 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode {
},
}
inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
- inode.EnableLeakCheck()
+ inode.InitRefs()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
return inode
}
@@ -265,7 +265,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode {
},
}
inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
- inode.EnableLeakCheck()
+ inode.InitRefs()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
return inode
}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index b81ea14bf..151d1f10d 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -83,7 +83,7 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns
cgroupControllers: cgroupControllers,
}
inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
- inode.EnableLeakCheck()
+ inode.InitRefs()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
links := inode.OrderedChildren.Populate(contents)
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 506a2a0f0..79bc3fe88 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -160,7 +160,7 @@ func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode
d := &dir{}
d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
- d.EnableLeakCheck()
+ d.InitRefs()
d.IncLinks(d.OrderedChildren.Populate(contents))
return d
}
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index d772db9e9..57e7b57b0 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
- "gvisor.dev/gvisor/pkg/usermem"
)
// +stateify savable
@@ -32,7 +31,7 @@ type namedPipe struct {
// * fs.mu must be locked.
// * rp.Mount().CheckBeginWrite() has been called successfully.
func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
- file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
+ file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize)}
file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 4ce859d57..85a3dfe20 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -402,7 +402,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth
i.mtime = now
// i.nlink initialized by caller
i.impl = impl
- i.refs.EnableLeakCheck()
+ i.refs.InitRefs()
}
// incLinksLocked increments i's link count.
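
The EnableLeakCheck → InitRefs renames above and below follow the refs_template change listed in the diffstat: judging from these call sites, construction now goes through a single InitRefs that sets the initial reference count and registers the object for leak checking in one step. A toy sketch of that constructor pattern, with invented names rather than the generated refs_template API:

package main

import (
	"fmt"
	"sync/atomic"
)

// leakRegistry is a stand-in for the global leak-check registry.
var leakRegistry = map[*objRefs]struct{}{}

type objRefs struct {
	refs int64
}

// InitRefs sets the initial count and registers for leak checking in one
// call, so construction sites cannot forget either half.
func (r *objRefs) InitRefs() {
	atomic.StoreInt64(&r.refs, 1)
	leakRegistry[r] = struct{}{}
}

func (r *objRefs) IncRef() { atomic.AddInt64(&r.refs, 1) }

func (r *objRefs) DecRef(destroy func()) {
	if atomic.AddInt64(&r.refs, -1) == 0 {
		delete(leakRegistry, r)
		if destroy != nil {
			destroy()
		}
	}
}

func main() {
	var r objRefs
	r.InitRefs()
	r.IncRef()
	r.DecRef(nil)
	r.DecRef(func() { fmt.Println("destroyed; leaks outstanding:", len(leakRegistry)) })
}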
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 2f6050cfd..4e8d63d51 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -276,9 +276,9 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
UID: parentStat.UID,
GID: parentStat.GID,
//TODO(b/156980949): Support passing other hash algorithms.
- HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: int64(offset),
- ReadSize: int64(merkletree.DigestSize(linux.FS_VERITY_HASH_ALG_SHA256)),
+ ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
Expected: parent.hash,
DataAndTreeInSameFile: true,
}); err != nil && err != io.EOF {
@@ -352,7 +352,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
UID: stat.UID,
GID: stat.GID,
//TODO(b/156980949): Support passing other hash algorithms.
- HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: 0,
// Set read size to 0 so only the metadata is verified.
ReadSize: 0,
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index de92878fd..faa862c55 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -79,6 +79,27 @@ var (
verityMu sync.RWMutex
)
+// HashAlgorithm is a type specifying the algorithm used to hash the file
+// content.
+type HashAlgorithm int
+
+// Currently supported hashing algorithms include SHA256 and SHA512.
+const (
+ SHA256 HashAlgorithm = iota
+ SHA512
+)
+
+func (alg HashAlgorithm) toLinuxHashAlg() int {
+ switch alg {
+ case SHA256:
+ return linux.FS_VERITY_HASH_ALG_SHA256
+ case SHA512:
+ return linux.FS_VERITY_HASH_ALG_SHA512
+ default:
+ return 0
+ }
+}
+
// FilesystemType implements vfs.FilesystemType.
//
// +stateify savable
@@ -108,6 +129,10 @@ type filesystem struct {
// stores the root hash of the whole file system in bytes.
rootDentry *dentry
+ // alg is the algorithm used to hash the files in the verity file
+ // system.
+ alg HashAlgorithm
+
// renameMu synchronizes renaming with non-renaming operations in order
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
@@ -136,6 +161,10 @@ type InternalFilesystemOptions struct {
// LowerName is the name of the filesystem wrapped by verity fs.
LowerName string
+ // Alg is the algorithm used to hash the files in the verity file
+ // system.
+ Alg HashAlgorithm
+
// RootHash is the root hash of the overall verity file system.
RootHash []byte
@@ -194,6 +223,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs := &filesystem{
creds: creds.Fork(),
+ alg: iopts.Alg,
lowerMount: mnt,
allowRuntimeEnable: iopts.AllowRuntimeEnable,
}
@@ -350,7 +380,9 @@ func (fs *filesystem) newDentry() *dentry {
// IncRef implements vfs.DentryImpl.IncRef.
func (d *dentry) IncRef() {
r := atomic.AddInt64(&d.refs, 1)
- refsvfs2.LogIncRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogIncRef(d, r)
+ }
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
@@ -361,7 +393,9 @@ func (d *dentry) TryIncRef() bool {
return false
}
if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
- refsvfs2.LogTryIncRef(d, r+1)
+ if d.LogRefs() {
+ refsvfs2.LogTryIncRef(d, r+1)
+ }
return true
}
}
@@ -370,7 +404,9 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
r := atomic.AddInt64(&d.refs, -1)
- refsvfs2.LogDecRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogDecRef(d, r)
+ }
if r == 0 {
d.fs.renameMu.Lock()
d.checkDropLocked(ctx)
@@ -382,7 +418,9 @@ func (d *dentry) DecRef(ctx context.Context) {
func (d *dentry) decRefLocked(ctx context.Context) {
r := atomic.AddInt64(&d.refs, -1)
- refsvfs2.LogDecRef(d, r)
+ if d.LogRefs() {
+ refsvfs2.LogDecRef(d, r)
+ }
if r == 0 {
d.checkDropLocked(ctx)
} else if r < 0 {
@@ -627,7 +665,7 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
TreeReader: &merkleReader,
TreeWriter: &merkleWriter,
//TODO(b/156980949): Support passing other hash algorithms.
- HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
}
switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT {
@@ -873,7 +911,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
UID: fd.d.uid,
GID: fd.d.gid,
//TODO(b/156980949): Support passing other hash algorithms.
- HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
ReadOffset: offset,
ReadSize: dst.NumBytes(),
Expected: fd.d.hash,
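
With the HashAlgorithm type above, the filesystem's configured algorithm is translated to the linux constant at each verification site via toLinuxHashAlg, and the Merkle read size is derived from merkletree.DigestSize for that constant. The sketch below illustrates the same mapping in plain Go, turning the algorithm into a hash constructor and digest size; the method names here are invented for the example, since the verity code itself only forwards the linux constant to the merkletree package.

package main

import (
	"crypto/sha256"
	"crypto/sha512"
	"fmt"
	"hash"
)

type HashAlgorithm int

const (
	SHA256 HashAlgorithm = iota
	SHA512
)

// newHash returns a hash.Hash for the algorithm, defaulting to SHA256 for
// unknown values (a choice made for this sketch only).
func (alg HashAlgorithm) newHash() hash.Hash {
	switch alg {
	case SHA512:
		return sha512.New()
	default:
		return sha256.New()
	}
}

// digestSize plays the role merkletree.DigestSize plays for the linux
// constant: the per-node digest length used when sizing Merkle tree reads.
func (alg HashAlgorithm) digestSize() int { return alg.newHash().Size() }

func main() {
	fmt.Println(SHA256.digestSize(), SHA512.digestSize()) // 32 64
}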
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index c647cbfd3..b2da9dd96 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -43,7 +43,7 @@ const maxDataSize = 100000
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is nil, then cleanup
// should be called when the root is no longer needed.
-func newVerityRoot(t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
k, err := testutil.Boot()
if err != nil {
t.Fatalf("testutil.Boot: %v", err)
@@ -70,6 +70,7 @@ func newVerityRoot(t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *ke
InternalData: InternalFilesystemOptions{
RootMerkleFileName: rootMerkleFilename,
LowerName: "tmpfs",
+ Alg: hashAlg,
AllowRuntimeEnable: true,
NoCrashOnVerificationFailure: true,
},
@@ -161,280 +162,296 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er
return nil
}
+var hashAlgs = []HashAlgorithm{SHA256, SHA512}
+
// TestOpen ensures that when a file is created, the corresponding Merkle tree
// file and the root Merkle tree file exist.
func TestOpen(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Ensure that the corresponding Merkle tree file is created.
- lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
- t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
- }
-
- // Ensure the root merkle tree file is created.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + rootMerkleFilename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
- t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Ensure that the corresponding Merkle tree file is created.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
+ }
+
+ // Ensure the root merkle tree file is created.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + rootMerkleFilename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+ }
}
}
// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity
// file succeeds after enabling verity for it.
func TestPReadUnmodifiedFileSucceeds(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- buf := make([]byte, size)
- n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
- if err != nil && err != io.EOF {
- t.Fatalf("fd.PRead: %v", err)
- }
-
- if n != int64(size) {
- t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
}
}
// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity
// file succeeds after enabling verity for it.
func TestReadUnmodifiedFileSucceeds(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- buf := make([]byte, size)
- n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
- if err != nil && err != io.EOF {
- t.Fatalf("fd.Read: %v", err)
- }
-
- if n != int64(size) {
- t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.Read got read length %d, want %d", n, size)
+ }
}
}
// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
// succeeds after enabling verity for it.
func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Ensure reopening the verity enabled file succeeds.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != nil {
- t.Errorf("reopen enabled file failed: %v", err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Ensure reopening the verity enabled file succeeds.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != nil {
+ t.Errorf("reopen enabled file failed: %v", err)
+ }
}
}
// TestPReadModifiedFileFails ensures that read from a modified verity file
// fails.
func TestPReadModifiedFileFails(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerFD that's read/writable.
- lowerVD := fd.Impl().(*fileDescription).d.lowerVD
-
- lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerVD,
- Start: lowerVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
- }
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
- // Confirm that read from the modified file fails.
- buf := make([]byte, size)
- if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- t.Fatalf("fd.PRead succeeded, expected failure")
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.PRead succeeded, expected failure")
+ }
}
}
// TestReadModifiedFileFails ensures that read from a modified verity file
// fails.
func TestReadModifiedFileFails(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerFD that's read/writable.
- lowerVD := fd.Impl().(*fileDescription).d.lowerVD
-
- lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerVD,
- Start: lowerVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
- }
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
- // Confirm that read from the modified file fails.
- buf := make([]byte, size)
- if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil {
- t.Fatalf("fd.Read succeeded, expected failure")
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.Read succeeded, expected failure")
+ }
}
}
// TestModifiedMerkleFails ensures that read from a verity file fails if the
// corresponding Merkle tree file is modified.
func TestModifiedMerkleFails(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerMerkleFD that's read/writable.
- lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
-
- lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerMerkleVD,
- Start: lowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- // Flip a random bit in the Merkle tree file.
- stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- t.Fatalf("stat: %v", err)
- }
- merkleSize := int(stat.Size)
- if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
- }
-
- // Confirm that read from a file with modified Merkle tree fails.
- buf := make([]byte, size)
- if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- fmt.Println(buf)
- t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
+
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the Merkle tree file.
+ stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ merkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from a file with modified Merkle tree fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ fmt.Println(buf)
+ t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ }
}
}
@@ -442,140 +459,146 @@ func TestModifiedMerkleFails(t *testing.T) {
// verity enabled directory fails if the hashes related to the target file in
// the parent Merkle tree file is modified.
func TestModifiedParentMerkleFails(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Enable verity on the parent directory.
- parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerMerkleFD that's read/writable.
- parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
-
- parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: parentLowerMerkleVD,
- Start: parentLowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- // Flip a random bit in the parent Merkle tree file.
- // This parent directory contains only one child, so any random
- // modification in the parent Merkle tree should cause verification
- // failure when opening the child file.
- stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- t.Fatalf("stat: %v", err)
- }
- parentMerkleSize := int(stat.Size)
- if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
- }
-
- parentLowerMerkleFD.DecRef(ctx)
-
- // Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err == nil {
- t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Enable verity on the parent directory.
+ parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
+
+ parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: parentLowerMerkleVD,
+ Start: parentLowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the parent Merkle tree file.
+ // This parent directory contains only one child, so any random
+ // modification in the parent Merkle tree should cause verification
+ // failure when opening the child file.
+ stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ parentMerkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ parentLowerMerkleFD.DecRef(ctx)
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err == nil {
+ t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ }
}
}
// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file
// succeeds after enabling verity for it.
func TestUnmodifiedStatSucceeds(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirms stat succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
-
- if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
- t.Errorf("fd.Stat: %v", err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm stat succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
+ t.Errorf("fd.Stat: %v", err)
+ }
}
}
// TestModifiedStatFails checks that getting stat for a file with modified stat
// should fail.
func TestModifiedStatFails(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
-
- lowerFD := fd.Impl().(*fileDescription).lowerFD
- // Change the stat of the underlying file, and check that stat fails.
- if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: uint32(linux.STATX_MODE),
- Mode: 0777,
- },
- }); err != nil {
- t.Fatalf("lowerFD.SetStat: %v", err)
- }
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ lowerFD := fd.Impl().(*fileDescription).lowerFD
+ // Change the stat of the underlying file, and check that stat fails.
+ if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: uint32(linux.STATX_MODE),
+ Mode: 0777,
+ },
+ }); err != nil {
+ t.Fatalf("lowerFD.SetStat: %v", err)
+ }
- if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
- t.Errorf("fd.Stat succeeded when it should fail")
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
+ t.Errorf("fd.Stat succeeded when it should fail")
+ }
}
}
@@ -616,84 +639,86 @@ func TestOpenDeletedFileFails(t *testing.T) {
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) {
- vfsObj, root, ctx, err := newVerityRoot(t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
- rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
- if tc.remove {
- if tc.changeFile {
- if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(filename),
- }); err != nil {
- t.Fatalf("UnlinkAt: %v", err)
+ rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
+ if tc.remove {
+ if tc.changeFile {
+ if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(filename),
+ }); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
}
- }
- if tc.changeMerkleFile {
- if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + filename),
- }); err != nil {
- t.Fatalf("UnlinkAt: %v", err)
+ if tc.changeMerkleFile {
+ if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + filename),
+ }); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
}
- }
- } else {
- newFilename := "renamed-test-file"
- if tc.changeFile {
- if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(filename),
- }, &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(newFilename),
- }, &vfs.RenameOptions{}); err != nil {
- t.Fatalf("RenameAt: %v", err)
+ } else {
+ newFilename := "renamed-test-file"
+ if tc.changeFile {
+ if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(filename),
+ }, &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(newFilename),
+ }, &vfs.RenameOptions{}); err != nil {
+ t.Fatalf("RenameAt: %v", err)
+ }
}
- }
- if tc.changeMerkleFile {
- if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + filename),
- }, &vfs.PathOperation{
- Root: rootLowerVD,
- Start: rootLowerVD,
- Path: fspath.Parse(merklePrefix + newFilename),
- }, &vfs.RenameOptions{}); err != nil {
- t.Fatalf("UnlinkAt: %v", err)
+ if tc.changeMerkleFile {
+ if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + newFilename),
+ }, &vfs.RenameOptions{}); err != nil {
+ t.Fatalf("RenameAt: %v", err)
+ }
}
}
- }
- // Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != syserror.EIO {
- t.Errorf("got OpenAt error: %v, expected EIO", err)
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != syserror.EIO {
+ t.Errorf("got OpenAt error: %v, expected EIO", err)
+ }
}
})
}
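
The test changes above parameterize each body with a plain for loop over hashAlgs. A common alternative, not what this patch does, is to name each parameterization as a subtest so a failure reports which algorithm produced it. A sketch of that variant (it would live in a _test.go file in the verity package):

package verity

import (
	"fmt"
	"testing"
)

// TestOpenPerAlg sketches the subtest variant: each algorithm gets its own
// t.Run scope, so failures are attributed to a specific algorithm.
func TestOpenPerAlg(t *testing.T) {
	for _, alg := range hashAlgs {
		alg := alg // capture the loop variable for the closure
		t.Run(fmt.Sprintf("alg=%d", alg), func(t *testing.T) {
			// newVerityRoot(t, alg) and the per-algorithm assertions from
			// the tests above would go here.
			_ = alg
		})
	}
}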
diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD
index 364a78306..db3b0d0a0 100644
--- a/pkg/sentry/hostfd/BUILD
+++ b/pkg/sentry/hostfd/BUILD
@@ -6,10 +6,12 @@ go_library(
name = "hostfd",
srcs = [
"hostfd.go",
+ "hostfd_linux.go",
"hostfd_unsafe.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/log",
"//pkg/safemem",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/hostfd/hostfd_linux.go b/pkg/sentry/hostfd/hostfd_linux.go
new file mode 100644
index 000000000..1cabc848f
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd_linux.go
@@ -0,0 +1,18 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostfd
+
+// maxIov is the maximum permitted size of a struct iovec array.
+const maxIov = 1024 // UIO_MAXIOV
diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go
index cd4dc67fb..694371b1c 100644
--- a/pkg/sentry/hostfd/hostfd_unsafe.go
+++ b/pkg/sentry/hostfd/hostfd_unsafe.go
@@ -20,6 +20,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
)
@@ -44,6 +45,10 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6
}
} else {
iovs := safemem.IovecsFromBlockSeq(dsts)
+ if len(iovs) > maxIov {
+ log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov)
+ iovs = iovs[:maxIov]
+ }
n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
}
if e != 0 {
@@ -76,6 +81,10 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint
}
} else {
iovs := safemem.IovecsFromBlockSeq(srcs)
+ if len(iovs) > maxIov {
+ log.Debugf("hostfd.Pwritev2: truncating from %d iovecs to %d", len(iovs), maxIov)
+ iovs = iovs[:maxIov]
+ }
n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
}
if e != 0 {
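
Preadv2 and Pwritev2 above clamp the iovec array to UIO_MAXIOV (1024) before issuing the syscall, since the kernel rejects longer arrays with EINVAL; truncation simply yields a shorter transfer, which callers already handle through the returned byte count. A standalone sketch of the clamp, using the raw golang.org/x/sys/unix types rather than gVisor's safemem machinery:

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

const maxIov = 1024 // UIO_MAXIOV

// clampIovecs truncates iovs to the kernel's per-call limit. The caller
// observes a short read/write and is expected to retry for the remainder.
func clampIovecs(iovs []unix.Iovec) []unix.Iovec {
	if len(iovs) > maxIov {
		return iovs[:maxIov]
	}
	return iovs
}

func main() {
	iovs := make([]unix.Iovec, 1500)
	fmt.Println(len(clampIovecs(iovs))) // 1024
}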
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index 3476551f3..470d8bf83 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -43,7 +43,7 @@ func (f *FDTable) initNoLeakCheck() {
// init initializes the table with leak checking.
func (f *FDTable) init() {
f.initNoLeakCheck()
- f.EnableLeakCheck()
+ f.InitRefs()
}
// get gets a file entry.
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index 41fb2a784..dfde4deee 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -63,7 +63,7 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext {
cwd: cwd,
umask: umask,
}
- f.EnableLeakCheck()
+ f.InitRefs()
return &f
}
@@ -76,7 +76,7 @@ func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext {
cwdVFS2: cwd,
umask: umask,
}
- f.EnableLeakCheck()
+ f.InitRefs()
return &f
}
@@ -137,7 +137,7 @@ func (f *FSContext) Fork() *FSContext {
rootVFS2: f.rootVFS2,
umask: f.umask,
}
- ctx.EnableLeakCheck()
+ ctx.InitRefs()
return ctx
}
diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go
index b87e40dd1..9545bb5ef 100644
--- a/pkg/sentry/kernel/ipc_namespace.go
+++ b/pkg/sentry/kernel/ipc_namespace.go
@@ -41,7 +41,7 @@ func NewIPCNamespace(userNS *auth.UserNamespace) *IPCNamespace {
semaphores: semaphore.NewRegistry(userNS),
shms: shm.NewRegistry(userNS),
}
- ns.EnableLeakCheck()
+ ns.InitRefs()
return ns
}
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index ce0db5583..d6fb0fdb8 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
type sleeper struct {
@@ -66,7 +65,8 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag
d := fs.NewDirent(ctx, inode, "pipe")
file, err := n.GetFile(ctx, d, flags)
if err != nil {
- t.Fatalf("open with flags %+v failed: %v", flags, err)
+ t.Errorf("open with flags %+v failed: %v", flags, err)
+ return nil, err
}
if doneChan != nil {
doneChan <- struct{}{}
@@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.
}
func newNamedPipe(t *testing.T) *Pipe {
- return NewPipe(true, DefaultPipeSize, usermem.PageSize)
+ return NewPipe(true, DefaultPipeSize)
}
func newAnonPipe(t *testing.T) *Pipe {
- return NewPipe(false, DefaultPipeSize, usermem.PageSize)
+ return NewPipe(false, DefaultPipeSize)
}
// assertRecvBlocks ensures that a recv attempt on c blocks for at least
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 67beb0ad6..b989e14c7 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -26,18 +26,27 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
// MinimumPipeSize is a hard limit of the minimum size of a pipe.
- MinimumPipeSize = 64 << 10
+ // It corresponds to fs/pipe.c:pipe_min_size.
+ MinimumPipeSize = usermem.PageSize
+
+ // MaximumPipeSize is a hard limit on the maximum size of a pipe.
+ // It corresponds to fs/pipe.c:pipe_max_size.
+ MaximumPipeSize = 1048576
// DefaultPipeSize is the system-wide default size of a pipe in bytes.
- DefaultPipeSize = MinimumPipeSize
+ // It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS.
+ DefaultPipeSize = 16 * usermem.PageSize
- // MaximumPipeSize is a hard limit on the maximum size of a pipe.
- MaximumPipeSize = 8 << 20
+ // atomicIOBytes is the maximum number of bytes that the pipe
+ // guarantees will be read or written atomically.
+ // It corresponds to limits.h:PIPE_BUF.
+ atomicIOBytes = 4096
)
// Pipe is an encapsulation of a platform-independent pipe.
@@ -53,12 +62,6 @@ type Pipe struct {
// This value is immutable.
isNamed bool
- // atomicIOBytes is the maximum number of bytes that the pipe will
- // guarantee atomic reads or writes atomically.
- //
- // This value is immutable.
- atomicIOBytes int64
-
// The number of active readers for this pipe.
//
// Access atomically.
@@ -94,47 +97,34 @@ type Pipe struct {
// NewPipe initializes and returns a pipe.
//
-// N.B. The size and atomicIOBytes will be bounded.
-func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
+// N.B. The size will be bounded.
+func NewPipe(isNamed bool, sizeBytes int64) *Pipe {
if sizeBytes < MinimumPipeSize {
sizeBytes = MinimumPipeSize
}
if sizeBytes > MaximumPipeSize {
sizeBytes = MaximumPipeSize
}
- if atomicIOBytes <= 0 {
- atomicIOBytes = 1
- }
- if atomicIOBytes > sizeBytes {
- atomicIOBytes = sizeBytes
- }
var p Pipe
- initPipe(&p, isNamed, sizeBytes, atomicIOBytes)
+ initPipe(&p, isNamed, sizeBytes)
return &p
}
-func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) {
+func initPipe(pipe *Pipe, isNamed bool, sizeBytes int64) {
if sizeBytes < MinimumPipeSize {
sizeBytes = MinimumPipeSize
}
if sizeBytes > MaximumPipeSize {
sizeBytes = MaximumPipeSize
}
- if atomicIOBytes <= 0 {
- atomicIOBytes = 1
- }
- if atomicIOBytes > sizeBytes {
- atomicIOBytes = sizeBytes
- }
pipe.isNamed = isNamed
pipe.max = sizeBytes
- pipe.atomicIOBytes = atomicIOBytes
}
// NewConnectedPipe initializes a pipe and returns a pair of objects
// representing the read and write ends of the pipe.
-func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) {
- p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes)
+func NewConnectedPipe(ctx context.Context, sizeBytes int64) (*fs.File, *fs.File) {
+ p := NewPipe(false /* isNamed */, sizeBytes)
// Build an fs.Dirent for the pipe which will be shared by both
// returned files.
@@ -264,7 +254,7 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
wanted := ops.left()
avail := p.max - p.view.Size()
if wanted > avail {
- if wanted <= p.atomicIOBytes {
+ if wanted <= atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
ops.limit(avail)
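With atomicIOBytes now a package constant equal to PIPE_BUF, a write that does not fit in the available space fails outright only when it is small enough to need atomicity; larger writes are partially committed and the caller retries. A standalone sketch of that policy, assuming nothing beyond the constant shown above (names are illustrative):

package main

import (
	"errors"
	"fmt"
)

// atomicIOBytes mirrors limits.h:PIPE_BUF.
const atomicIOBytes = 4096

var errWouldBlock = errors.New("would block")

// pipeWrite mirrors writeLocked's policy: a write no larger than PIPE_BUF is
// all or nothing, while a larger write may be committed partially.
func pipeWrite(wanted, avail int64) (int64, error) {
	if wanted > avail {
		if wanted <= atomicIOBytes {
			return 0, errWouldBlock // Keep the write atomic.
		}
		wanted = avail // Commit what fits; the caller retries the rest.
	}
	return wanted, nil
}

func main() {
	fmt.Println(pipeWrite(100, 50))    // 0 would block
	fmt.Println(pipeWrite(8192, 4096)) // 4096 <nil>
}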
diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go
index fe97e9800..3dd739080 100644
--- a/pkg/sentry/kernel/pipe/pipe_test.go
+++ b/pkg/sentry/kernel/pipe/pipe_test.go
@@ -26,7 +26,7 @@ import (
func TestPipeRW(t *testing.T) {
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, 65536, 4096)
+ r, w := NewConnectedPipe(ctx, 65536)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -46,7 +46,7 @@ func TestPipeRW(t *testing.T) {
func TestPipeReadBlock(t *testing.T) {
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, 65536, 4096)
+ r, w := NewConnectedPipe(ctx, 65536)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -61,7 +61,7 @@ func TestPipeWriteBlock(t *testing.T) {
const capacity = MinimumPipeSize
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes)
+ r, w := NewConnectedPipe(ctx, capacity)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -76,7 +76,7 @@ func TestPipeWriteUntilEnd(t *testing.T) {
const atomicIOBytes = 2
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes)
+ r, w := NewConnectedPipe(ctx, atomicIOBytes)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -116,7 +116,8 @@ func TestPipeWriteUntilEnd(t *testing.T) {
}
}
if err != nil {
- t.Fatalf("Readv: got unexpected error %v", err)
+ t.Errorf("Readv: got unexpected error %v", err)
+ return
}
}
}()
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index d96bf253b..7b23cbe86 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -54,9 +54,9 @@ type VFSPipe struct {
}
// NewVFSPipe returns an initialized VFSPipe.
-func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
+func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe {
var vp VFSPipe
- initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes)
+ initPipe(&vp.pipe, isNamed, sizeBytes)
return &vp
}
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 1145faf13..1abfe2201 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -1000,7 +1000,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
// at the address specified by the data parameter, and the return value
// is the error flag." - ptrace(2)
word := t.Arch().Native(0)
- if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil {
+ if _, err := word.CopyIn(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr); err != nil {
return err
}
_, err := word.CopyOut(t, data)
@@ -1008,7 +1008,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA:
word := t.Arch().Native(uintptr(data))
- _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr)
+ _, err := word.CopyOut(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr)
return err
case linux.PTRACE_GETREGSET:
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index c39ecfb8f..b99c0bffa 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -103,6 +103,7 @@ type waiter struct {
waiterEntry
// value represents how much resource the waiter needs to wake up.
+ // The value is either 0 or negative.
value int16
ch chan struct{}
}
@@ -423,6 +424,42 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) {
return sem.pid, nil
}
+func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *waiter) bool) (uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // The calling process must have read permission on the semaphore set.
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return 0, syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return 0, syserror.ERANGE
+ }
+ var cnt uint16
+ for w := sem.waiters.Front(); w != nil; w = w.Next() {
+ if pred(w) {
+ cnt++
+ }
+ }
+ return cnt, nil
+}
+
+// CountZeroWaiters returns the number of waiters waiting for the sem's value to become zero.
+func (s *Set) CountZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) {
+ return s.countWaiters(num, creds, func(w *waiter) bool {
+ return w.value == 0
+ })
+}
+
+// CountNegativeWaiters returns the number of waiters waiting for the sem's value to increase.
+func (s *Set) CountNegativeWaiters(num int32, creds *auth.Credentials) (uint16, error) {
+ return s.countWaiters(num, creds, func(w *waiter) bool {
+ return w.value < 0
+ })
+}
+
// ExecuteOps attempts to execute a list of operations to the set. It only
// succeeds when all operations can be applied. No changes are made if it fails.
//
@@ -575,11 +612,18 @@ func (s *Set) destroy() {
}
}
+func abs(val int16) int16 {
+ if val < 0 {
+ return -val
+ }
+ return val
+}
+
// wakeWaiters goes over all waiters and checks which of them can be notified.
func (s *sem) wakeWaiters() {
// Note that this will release all waiters waiting for 0 too.
for w := s.waiters.Front(); w != nil; {
- if s.value < w.value {
+ if s.value < abs(w.value) {
// Still blocked, skip it.
w = w.Next()
continue
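countWaiters folds the GETZCNT and GETNCNT queries into one predicate-driven helper: a waiter value of 0 is a wait-for-zero operation, a negative value is an unsatisfied decrement. A self-contained sketch of the counting pattern over a simplified waiter slice (the real list is intrusive and protected by the set's mutex):

package main

import "fmt"

// waiter mirrors the semaphore waiter: value is 0 for a wait-for-zero
// operation and negative for a decrement that could not be satisfied.
type waiter struct {
	value int16
}

// countWaiters counts the waiters matching pred, the same pattern used to
// implement CountZeroWaiters and CountNegativeWaiters.
func countWaiters(ws []waiter, pred func(w waiter) bool) uint16 {
	var cnt uint16
	for _, w := range ws {
		if pred(w) {
			cnt++
		}
	}
	return cnt
}

func main() {
	ws := []waiter{{value: 0}, {value: -1}, {value: -3}}
	fmt.Println(countWaiters(ws, func(w waiter) bool { return w.value == 0 })) // 1 (GETZCNT)
	fmt.Println(countWaiters(ws, func(w waiter) bool { return w.value < 0 }))  // 2 (GETNCNT)
}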
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index df5c8421b..0cd9e2533 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -295,7 +295,7 @@ func (tg *ThreadGroup) createSession() error {
id: SessionID(id),
leader: tg,
}
- s.EnableLeakCheck()
+ s.InitRefs()
// Create a new ProcessGroup, belonging to that Session.
// This also has a single reference (assigned below).
@@ -309,7 +309,7 @@ func (tg *ThreadGroup) createSession() error {
session: s,
ancestors: 0,
}
- pg.refs.EnableLeakCheck()
+ pg.refs.InitRefs()
// Tie them and return the result.
s.processGroups.PushBack(pg)
@@ -395,7 +395,7 @@ func (tg *ThreadGroup) CreateProcessGroup() error {
originator: tg,
session: tg.processGroup.session,
}
- pg.refs.EnableLeakCheck()
+ pg.refs.InitRefs()
if tg.leader.parent != nil && tg.leader.parent.tg.processGroup.session == pg.session {
pg.ancestors++
@@ -477,20 +477,20 @@ func (tg *ThreadGroup) Session() *Session {
//
// If this group isn't visible in this namespace, zero will be returned. It is
// the callers responsibility to check that before using this function.
-func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.sids[s]
+func (ns *PIDNamespace) IDOfSession(s *Session) SessionID {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.sids[s]
}
// SessionWithID returns the Session with the given ID in the PID namespace ns,
// or nil if that given ID is not defined in this namespace.
//
// A reference is not taken on the session.
-func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.sessions[id]
+func (ns *PIDNamespace) SessionWithID(id SessionID) *Session {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.sessions[id]
}
// ProcessGroup returns the ThreadGroup's ProcessGroup.
@@ -505,18 +505,18 @@ func (tg *ThreadGroup) ProcessGroup() *ProcessGroup {
// IDOfProcessGroup returns the process group assigned to pg in PID namespace ns.
//
// The same constraints apply as IDOfSession.
-func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.pgids[pg]
+func (ns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.pgids[pg]
}
// ProcessGroupWithID returns the ProcessGroup with the given ID in the PID
// namespace ns, or nil if that given ID is not defined in this namespace.
//
// A reference is not taken on the process group.
-func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.processGroups[id]
+func (ns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.processGroups[id]
}
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index ebbebf46b..92d60ba78 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -251,7 +251,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi
creatorPID: pid,
changeTime: ktime.NowFromContext(ctx),
}
- shm.EnableLeakCheck()
+ shm.InitRefs()
// Find the next available ID.
for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ {
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 682080c14..527344162 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -355,7 +355,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
}
if opts.ChildSetTID {
ctid := nt.ThreadID()
- ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID)
+ ctid.CopyOut(nt.CopyContext(t, usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID)
}
ntid := t.tg.pidns.IDOfTask(nt)
if opts.ParentSetTID {
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index ce134bf54..94dabbcd8 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -18,7 +18,8 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -281,29 +282,89 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp
}, nil
}
-// copyContext implements marshal.CopyContext. It wraps a task to allow copying
-// memory to and from the task memory with custom usermem.IOOpts.
-type copyContext struct {
- *Task
+type taskCopyContext struct {
+ ctx context.Context
+ t *Task
opts usermem.IOOpts
}
-// AsCopyContext wraps the task and returns it as CopyContext.
-func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext {
- return &copyContext{t, opts}
+// CopyContext returns a marshal.CopyContext that copies to/from t's address
+// space using opts.
+func (t *Task) CopyContext(ctx context.Context, opts usermem.IOOpts) *taskCopyContext {
+ return &taskCopyContext{
+ ctx: ctx,
+ t: t,
+ opts: opts,
+ }
+}
+
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (cc *taskCopyContext) CopyScratchBuffer(size int) []byte {
+ if ctxTask, ok := cc.ctx.(*Task); ok {
+ return ctxTask.CopyScratchBuffer(size)
+ }
+ return make([]byte, size)
+}
+
+func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) {
+ cc.t.mu.Lock()
+ tmm := cc.t.MemoryManager()
+ cc.t.mu.Unlock()
+ if !tmm.IncUsers() {
+ return nil, syserror.EFAULT
+ }
+ return tmm, nil
+}
+
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+ tmm, err := cc.getMemoryManager()
+ if err != nil {
+ return 0, err
+ }
+ defer tmm.DecUsers(cc.ctx)
+ return tmm.CopyIn(cc.ctx, addr, dst, cc.opts)
+}
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+ tmm, err := cc.getMemoryManager()
+ if err != nil {
+ return 0, err
+ }
+ defer tmm.DecUsers(cc.ctx)
+ return tmm.CopyOut(cc.ctx, addr, src, cc.opts)
+}
+
+type ownTaskCopyContext struct {
+ t *Task
+ opts usermem.IOOpts
+}
+
+// OwnCopyContext returns a marshal.CopyContext that copies to/from t's address
+// space using opts. The returned CopyContext may only be used by t's task
+// goroutine.
+//
+// Since t already implements marshal.CopyContext, this is only needed to
+// override the usermem.IOOpts used for the copy.
+func (t *Task) OwnCopyContext(opts usermem.IOOpts) *ownTaskCopyContext {
+ return &ownTaskCopyContext{
+ t: t,
+ opts: opts,
+ }
}
-// CopyInString copies a string in from the task's memory.
-func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) {
- return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts)
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte {
+ return cc.t.CopyScratchBuffer(size)
}
-// CopyInBytes copies task memory into dst from an IO context.
-func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
- return t.MemoryManager().CopyIn(t, addr, dst, t.opts)
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+ return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts)
}
-// CopyOutBytes copies src into task memoryfrom an IO context.
-func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
- return t.MemoryManager().CopyOut(t, addr, src, t.opts)
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+ return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts)
}
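taskCopyContext differs from the removed copyContext mainly in that it pins the target task's MemoryManager (IncUsers/DecUsers) around every copy, so copying into another task's address space, as ptrace now does via target.CopyContext(t, ...), cannot race with address-space teardown. A toy model of that guard with stand-in types rather than the real mm package:

package main

import (
	"errors"
	"fmt"
)

// memoryManager stands in for mm.MemoryManager: IncUsers fails once the
// address space has been torn down.
type memoryManager struct {
	live bool
	mem  []byte
}

func (m *memoryManager) IncUsers() bool { return m.live }
func (m *memoryManager) DecUsers()      {}

var errFault = errors.New("EFAULT")

// copyOutBytes mirrors taskCopyContext.CopyOutBytes: pin the target's memory
// manager for the duration of the copy, then release it.
func copyOutBytes(m *memoryManager, addr int, src []byte) (int, error) {
	if !m.IncUsers() {
		return 0, errFault
	}
	defer m.DecUsers()
	return copy(m.mem[addr:], src), nil
}

func main() {
	m := &memoryManager{live: true, mem: make([]byte, 8)}
	n, err := copyOutBytes(m, 2, []byte{1, 2, 3})
	fmt.Println(n, err, m.mem) // 3 <nil> [0 0 1 2 3 0 0 0]
}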
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index 9bc452e67..9e5c2d26f 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -115,7 +115,7 @@ func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error {
}
if old != v.seq {
- return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. Application may hang or get incorrect time from the VDSO.", old, v.seq)
+ return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d; application may hang or get incorrect time from the VDSO", old, v.seq)
}
v.seq = next
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 7bf48cb2c..4c8cd38ed 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -252,7 +252,7 @@ func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) {
return nil, err
}
m := aioMappable{mfp: mfp, fr: fr}
- m.EnableLeakCheck()
+ m.InitRefs()
return &m, nil
}
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index 2dbe5b751..48d8b6a2b 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -44,7 +44,7 @@ type SpecialMappable struct {
// Preconditions: fr.Length() != 0.
func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable {
m := SpecialMappable{mfp: mfp, fr: fr, name: name}
- m.EnableLeakCheck()
+ m.InitRefs()
return &m
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 0a54dd30d..acad4c793 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -79,6 +79,18 @@ func bluepillStopGuest(c *vCPU) {
c.runData.requestInterruptWindow = 0
}
+// bluepillSigBus is responsible for injecting an NMI to trigger SIGBUS.
+//
+//go:nosplit
+func bluepillSigBus(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_NMI, 0); errno != 0 {
+ throw("NMI injection failed")
+ }
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
//
//go:nosplit
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 58f3d6fdd..965ad66b5 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -27,15 +27,20 @@ var (
// The action for bluepillSignal is changed by sigaction().
bluepillSignal = syscall.SIGILL
- // vcpuSErr is the event of system error.
- vcpuSErr = kvmVcpuEvents{
+ // vcpuSErrBounce is the system error event used for bouncing KVM.
+ vcpuSErrBounce = kvmVcpuEvents{
exception: exception{
sErrPending: 1,
- sErrHasEsr: 0,
- pad: [6]uint8{0, 0, 0, 0, 0, 0},
- sErrEsr: 1,
},
- rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ }
+
+ // vcpuSErrNMI is the system error event used to trigger SIGBUS.
+ vcpuSErrNMI = kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ sErrHasEsr: 1,
+ sErrEsr: _ESR_ELx_SERR_NMI,
+ },
}
)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index b35c930e2..9433d4da5 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -80,11 +80,24 @@ func getHypercallID(addr uintptr) int {
//
//go:nosplit
func bluepillStopGuest(c *vCPU) {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
- uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 {
+ uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 {
+ throw("sErr injection failed")
+ }
+}
+
+// bluepillSigBus is responsible for injecting an sError to trigger SIGBUS.
+//
+//go:nosplit
+func bluepillSigBus(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 {
throw("sErr injection failed")
}
}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index eb05950cd..75085ac6a 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -146,12 +146,7 @@ func bluepillHandler(context unsafe.Pointer) {
// MMIO exit we receive EFAULT from the run ioctl. We
// always inject an NMI here since we may be in kernel
// mode and have interrupts disabled.
- if _, _, errno := syscall.RawSyscall( // escapes: no.
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_NMI, 0); errno != 0 {
- throw("NMI injection failed")
- }
+ bluepillSigBus(c)
continue // Rerun vCPU.
default:
throw("run failed")
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index dd45ad10b..5979aef97 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -158,8 +158,7 @@ func (*KVM) MaxUserAddress() usermem.Addr {
// NewAddressSpace returns a new pagetable root.
func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
// Allocate page tables and install system mappings.
- pageTables := pagetables.New(newAllocator())
- k.machine.mapUpperHalf(pageTables)
+ pageTables := pagetables.NewWithUpper(newAllocator(), k.machine.upperSharedPageTables, ring0.KernelStartAddress)
// Return the new address space.
return &addressSpace{
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 5831b9345..b060d9544 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -151,6 +151,9 @@ const (
_ESR_SEGV_PEMERR_L1 = 0xd
_ESR_SEGV_PEMERR_L2 = 0xe
_ESR_SEGV_PEMERR_L3 = 0xf
+
+ // Custom ISS field definitions for system error.
+ _ESR_ELx_SERR_NMI = 0x1
)
// Arm64: MMIO base address used to dispatch hypercalls.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index f70d761fd..e2fffc99b 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -41,6 +41,9 @@ type machine struct {
// slots are currently being updated, and the caller should retry.
nextSlot uint32
+ // upperSharedPageTables tracks the read-only shared upper of all the pagetables.
+ upperSharedPageTables *pagetables.PageTables
+
// kernel is the set of global structures.
kernel ring0.Kernel
@@ -199,9 +202,7 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
m.vCPUsByTID = make(map[uint64]*vCPU)
m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
- m.kernel.Init(ring0.KernelOpts{
- PageTables: pagetables.New(newAllocator()),
- }, m.maxVCPUs)
+ m.kernel.Init(m.maxVCPUs)
// Pull the maximum slots.
maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS)
@@ -213,6 +214,13 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of slots is %d.", m.maxSlots)
m.usedSlots = make([]uintptr, m.maxSlots)
+ // Create the upper shared pagetables and kernel(sentry) pagetables.
+ m.upperSharedPageTables = pagetables.New(newAllocator())
+ m.mapUpperHalf(m.upperSharedPageTables)
+ m.upperSharedPageTables.Allocator.(*allocator).base.Drain()
+ m.upperSharedPageTables.MarkReadOnlyShared()
+ m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -226,7 +234,6 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
- m.mapUpperHalf(m.kernel.PageTables)
var physicalRegionsReadOnly []physicalRegion
var physicalRegionsAvailable []physicalRegion
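The kernel's upper-half mappings are now built once, drained, and marked read-only shared; each address space then clones those upper root entries instead of re-mapping the sentry. A toy model of the share-then-clone step (array contents and sizes are illustrative, not gVisor's real layout):

package main

import "fmt"

const entriesPerPage = 512

// pageTables is a toy model of the shared-upper scheme: the upper half of
// the root is cloned from a read-only shared table rather than rebuilt for
// every address space.
type pageTables struct {
	root           [entriesPerPage]uint64
	upperShared    *pageTables
	readOnlyShared bool
}

func (p *pageTables) markReadOnlyShared() { p.readOnlyShared = true }

// newWithUpper clones the upper half of the shared kernel tables into a
// fresh set of tables, mirroring NewWithUpper plus cloneUpperShared.
func newWithUpper(upper *pageTables) *pageTables {
	if !upper.readOnlyShared {
		panic("upper tables must be marked read-only shared")
	}
	p := &pageTables{upperShared: upper}
	copy(p.root[entriesPerPage/2:], upper.root[entriesPerPage/2:])
	return p
}

func main() {
	kernelUpper := &pageTables{}
	kernelUpper.root[entriesPerPage/2] = 0xdeadbeef // A kernel mapping.
	kernelUpper.markReadOnlyShared()

	as := newWithUpper(kernelUpper)
	fmt.Printf("%#x\n", as.root[entriesPerPage/2]) // 0xdeadbeef, inherited by the clone.
}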
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index a8b729e62..8e03c310d 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -432,30 +432,27 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return physicalRegions
}
-var execRegions = func() (regions []region) {
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ // Map all the executable regions so that all the entry functions
+ // are mapped in the upper half.
applyVirtualRegions(func(vr virtualRegion) {
if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
return
}
+
if vr.accessType.Execute {
- regions = append(regions, vr.region)
+ r := vr.region
+ physical, length, ok := translateToPhysical(r.virtual)
+ if !ok || length < r.length {
+ panic("impossible translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|r.virtual),
+ r.length,
+ pagetables.MapOpts{AccessType: usermem.Execute},
+ physical)
}
})
- return
-}()
-
-func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
- for _, r := range execRegions {
- physical, length, ok := translateToPhysical(r.virtual)
- if !ok || length < r.length {
- panic("impossilbe translation")
- }
- pageTable.Map(
- usermem.Addr(ring0.KernelStartAddress|r.virtual),
- r.length,
- pagetables.MapOpts{AccessType: usermem.Execute},
- physical)
- }
for start, end := range m.kernel.EntryRegions() {
regionLen := end - start
physical, length, ok := translateToPhysical(start)
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 1344ed3c9..fd92c3873 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -221,7 +221,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info)
} else if !ring0.IsCanonical(regs.Sp) {
- return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
+ return nonCanonical(regs.Sp, int32(syscall.SIGSEGV), info)
}
// Assign PCIDs.
@@ -257,11 +257,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
case ring0.PageFault:
return c.fault(int32(syscall.SIGSEGV), info)
+ case ring0.El0ErrNMI:
+ return c.fault(int32(syscall.SIGBUS), info)
case ring0.Vector(bounce): // ring0.VirtualizationException
return usermem.NoAccess, platform.ErrContextInterrupt
- case ring0.El0Sync_undef:
+ case ring0.El0SyncUndef:
return c.fault(int32(syscall.SIGILL), info)
- case ring0.El1Sync_undef:
+ case ring0.El1SyncUndef:
*info = arch.SignalInfo{
Signo: int32(syscall.SIGILL),
Code: 1, // ILL_ILLOPC (illegal opcode).
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index 87a573cc4..327d48465 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -58,46 +58,55 @@ type Vector uintptr
// Exception vectors.
const (
- El1SyncInvalid = iota
- El1IrqInvalid
- El1FiqInvalid
- El1ErrorInvalid
+ El1InvSync = iota
+ El1InvIrq
+ El1InvFiq
+ El1InvError
+
El1Sync
El1Irq
El1Fiq
- El1Error
+ El1Err
+
El0Sync
El0Irq
El0Fiq
- El0Error
- El0Sync_invalid
- El0Irq_invalid
- El0Fiq_invalid
- El0Error_invalid
- El1Sync_da
- El1Sync_ia
- El1Sync_sp_pc
- El1Sync_undef
- El1Sync_dbg
- El1Sync_inv
- El0Sync_svc
- El0Sync_da
- El0Sync_ia
- El0Sync_fpsimd_acc
- El0Sync_sve_acc
- El0Sync_sys
- El0Sync_sp_pc
- El0Sync_undef
- El0Sync_dbg
- El0Sync_inv
+ El0Err
+
+ El0InvSync
+ El0InvIrq
+ El0InvFiq
+ El0InvErr
+
+ El1SyncDa
+ El1SyncIa
+ El1SyncSpPc
+ El1SyncUndef
+ El1SyncDbg
+ El1SyncInv
+
+ El0SyncSVC
+ El0SyncDa
+ El0SyncIa
+ El0SyncFpsimdAcc
+ El0SyncSveAcc
+ El0SyncSys
+ El0SyncSpPc
+ El0SyncUndef
+ El0SyncDbg
+ El0SyncInv
+
+ El0ErrNMI
+ El0ErrBounce
+
_NR_INTERRUPTS
)
// System call vectors.
const (
- Syscall Vector = El0Sync_svc
- PageFault Vector = El0Sync_da
- VirtualizationException Vector = El0Error
+ Syscall Vector = El0SyncSVC
+ PageFault Vector = El0SyncDa
+ VirtualizationException Vector = El0ErrBounce
)
// VirtualAddressBits returns the number bits available for virtual addresses.
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
index e6daf24df..f9765771e 100644
--- a/pkg/sentry/platform/ring0/defs.go
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -23,6 +23,9 @@ import (
//
// This contains global state, shared by multiple CPUs.
type Kernel struct {
+ // PageTables are the kernel pagetables; this must be provided.
+ PageTables *pagetables.PageTables
+
KernelArchState
}
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 00899273e..7a2275558 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -66,17 +66,9 @@ var (
KernelDataSegment SegmentDescriptor
)
-// KernelOpts has initialization options for the kernel.
-type KernelOpts struct {
- // PageTables are the kernel pagetables; this must be provided.
- PageTables *pagetables.PageTables
-}
-
// KernelArchState contains architecture-specific state.
type KernelArchState struct {
- KernelOpts
-
- // cpuEntries is array of kernelEntry for all cpus
+ // cpuEntries is array of kernelEntry for all cpus.
cpuEntries []kernelEntry
// globalIDT is our set of interrupt gates.
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
index 508236e46..a014dcbc0 100644
--- a/pkg/sentry/platform/ring0/defs_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -32,15 +32,8 @@ var (
KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
)
-// KernelOpts has initialization options for the kernel.
-type KernelOpts struct {
- // PageTables are the kernel pagetables; this must be provided.
- PageTables *pagetables.PageTables
-}
-
// KernelArchState contains architecture-specific state.
type KernelArchState struct {
- KernelOpts
}
// CPUArchState contains CPU-specific arch state.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index f9278b653..f489ad352 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -288,6 +288,10 @@
#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0)
#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1)
+/* ISS field definitions for system error */
+#define ESR_ELx_SERR_MASK (0x1)
+#define ESR_ELx_SERR_NMI (0x1)
+
// LOAD_KERNEL_ADDRESS loads a kernel address.
#define LOAD_KERNEL_ADDRESS(from, to) \
MOVD from, to; \
@@ -691,7 +695,7 @@ el0_sp_pc:
B ·Shutdown(SB)
el0_undef:
- EXCEPTION_WITH_ERROR(1, El0Sync_undef)
+ EXCEPTION_WITH_ERROR(1, El0SyncUndef)
el0_dbg:
B ·Shutdown(SB)
@@ -707,6 +711,29 @@ TEXT ·El0_fiq(SB),NOSPLIT,$0
TEXT ·El0_error(SB),NOSPLIT,$0
KERNEL_ENTRY_FROM_EL0
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ AND $ESR_ELx_SERR_MASK, R25, R24
+ CMP $ESR_ELx_SERR_NMI, R24
+ BEQ el0_nmi
+ B el0_bounce
+el0_nmi:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $El0ErrNMI, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ MRS ESR_EL1, R3
+ MOVD R3, CPU_ERROR_CODE(RSV_REG)
+
+ B ·kernelExitToEl1(SB)
+
+el0_bounce:
WORD $0xd538d092 //MRS TPIDR_EL1, R18
WORD $0xd538601a //MRS FAR_EL1, R26
@@ -718,7 +745,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0
MOVD $VirtualizationException, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·HaltAndResume(SB)
+ B ·kernelExitToEl1(SB)
TEXT ·El0_sync_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 264be23d3..292f9d0cc 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -16,11 +16,9 @@ package ring0
// Init initializes a new kernel.
//
-// N.B. that constraints on KernelOpts must be satisfied.
-//
//go:nosplit
-func (k *Kernel) Init(opts KernelOpts, maxCPUs int) {
- k.init(opts, maxCPUs)
+func (k *Kernel) Init(maxCPUs int) {
+ k.init(maxCPUs)
}
// Halt halts execution.
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index 3a9dff4cc..b55dc29b3 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -24,10 +24,7 @@ import (
)
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
- // Save the root page tables.
- k.PageTables = opts.PageTables
-
+func (k *Kernel) init(maxCPUs int) {
entrySize := reflect.TypeOf(kernelEntry{}).Size()
var (
entries []kernelEntry
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index b294ccc7c..6cbbf001f 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -25,9 +25,7 @@ func HaltAndResume()
func HaltEl1SvcAndResume()
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
- // Save the root page tables.
- k.PageTables = opts.PageTables
+func (k *Kernel) init(maxCPUs int) {
}
// init initializes architecture-specific state.
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 45eba960d..53bc3353c 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -47,43 +47,36 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
fmt.Fprintf(w, "\n// Vectors.\n")
- fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid)
- fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid)
- fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid)
- fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid)
fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync)
fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq)
fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq)
- fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error)
+ fmt.Fprintf(w, "#define El1Err 0x%02x\n", El1Err)
fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync)
fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq)
fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq)
- fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error)
+ fmt.Fprintf(w, "#define El0Err 0x%02x\n", El0Err)
- fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid)
- fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid)
- fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid)
- fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid)
+ fmt.Fprintf(w, "#define El1SyncDa 0x%02x\n", El1SyncDa)
+ fmt.Fprintf(w, "#define El1SyncIa 0x%02x\n", El1SyncIa)
+ fmt.Fprintf(w, "#define El1SyncSpPc 0x%02x\n", El1SyncSpPc)
+ fmt.Fprintf(w, "#define El1SyncUndef 0x%02x\n", El1SyncUndef)
+ fmt.Fprintf(w, "#define El1SyncDbg 0x%02x\n", El1SyncDbg)
+ fmt.Fprintf(w, "#define El1SyncInv 0x%02x\n", El1SyncInv)
- fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da)
- fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia)
- fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc)
- fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef)
- fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg)
- fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv)
+ fmt.Fprintf(w, "#define El0SyncSVC 0x%02x\n", El0SyncSVC)
+ fmt.Fprintf(w, "#define El0SyncDa 0x%02x\n", El0SyncDa)
+ fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa)
+ fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc)
+ fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc)
+ fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys)
+ fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc)
+ fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef)
+ fmt.Fprintf(w, "#define El0SyncDbg 0x%02x\n", El0SyncDbg)
+ fmt.Fprintf(w, "#define El0SyncInv 0x%02x\n", El0SyncInv)
- fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc)
- fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da)
- fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia)
- fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc)
- fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc)
- fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys)
- fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc)
- fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef)
- fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg)
- fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv)
+ fmt.Fprintf(w, "#define El0ErrNMI 0x%02x\n", El0ErrNMI)
fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index 7f18ac296..bc16a1622 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -30,6 +30,10 @@ type PageTables struct {
Allocator Allocator
// root is the pagetable root.
+ //
+ // For some architectures, such as amd64, the upper portion of the PTEs
+ // is cloned from and owned by upperSharedPageTables, which is shared
+ // among many PageTables when upperSharedPageTables is not nil.
root *PTEs
// rootPhysical is the cached physical address of the root.
@@ -39,15 +43,52 @@ type PageTables struct {
// archPageTables includes architecture-specific features.
archPageTables
+
+ // upperSharedPageTables represents a read-only shared upper portion
+ // of the PageTables. When it is not nil, the upper portion is not
+ // allowed to be modified.
+ upperSharedPageTables *PageTables
+
+ // upperStart is the start address of the upper portion that
+ // is shared from upperSharedPageTables.
+ upperStart uintptr
+
+ // readOnlyShared indicates that the PageTables are read-only and
+ // own the ranges that are shared with other PageTables.
+ readOnlyShared bool
}
-// New returns new PageTables.
-func New(a Allocator) *PageTables {
+// NewWithUpper returns new PageTables.
+//
+// upperSharedPageTables is used for mapping the upper portion of addresses,
+// starting at upperStart. These page tables should not be touched (as
+// invalidations may be incorrect) after they are passed as
+// upperSharedPageTables; only once all dependent PageTables are gone may
+// they be touched again. The intended use case is for kernel page tables,
+// which are static and fixed.
+//
+// Precondition: upperStart must be in a canonical range.
+// Precondition: upperStart must be pgdSize aligned.
+// Precondition: upperSharedPageTables must be marked read-only shared.
+func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables {
p := new(PageTables)
p.Init(a)
+ if upperSharedPageTables != nil {
+ if !upperSharedPageTables.readOnlyShared {
+ panic("Only read-only shared pagetables can be used as upper")
+ }
+ p.upperSharedPageTables = upperSharedPageTables
+ p.upperStart = upperStart
+ p.cloneUpperShared()
+ }
return p
}
+// New returns new PageTables.
+func New(a Allocator) *PageTables {
+ return NewWithUpper(a, nil, 0)
+}
+
// mapVisitor is used for map.
type mapVisitor struct {
target uintptr // Input.
@@ -90,6 +131,21 @@ func (*mapVisitor) requiresSplit() bool { return true }
//
//go:nosplit
func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
+ if p.readOnlyShared {
+ panic("Should not modify read-only shared pagetables.")
+ }
+ if uintptr(addr)+length < uintptr(addr) {
+ panic("addr & length overflow")
+ }
+ if p.upperSharedPageTables != nil {
+ // Ignore changes to the read-only upper shared portion.
+ if uintptr(addr) >= p.upperStart {
+ return false
+ }
+ if uintptr(addr)+length > p.upperStart {
+ length = p.upperStart - uintptr(addr)
+ }
+ }
if !opts.AccessType.Any() {
return p.Unmap(addr, length)
}
@@ -128,12 +184,27 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// True is returned iff there was a previous mapping in the range.
//
-// Precondition: addr & length must be page-aligned.
+// Precondition: addr & length must be page-aligned, their sum must not overflow.
//
// +checkescape:hard,stack
//
//go:nosplit
func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
+ if p.readOnlyShared {
+ panic("Should not modify read-only shared pagetables.")
+ }
+ if uintptr(addr)+length < uintptr(addr) {
+ panic("addr & length overflow")
+ }
+ if p.upperSharedPageTables != nil {
+ // Ignore changes to the read-only upper shared portion.
+ if uintptr(addr) >= p.upperStart {
+ return false
+ }
+ if uintptr(addr)+length > p.upperStart {
+ length = p.upperStart - uintptr(addr)
+ }
+ }
w := unmapWalker{
pageTables: p,
visitor: unmapVisitor{
@@ -218,3 +289,10 @@ func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts)
w.iterateRange(uintptr(addr), uintptr(addr)+1)
return w.visitor.physical + offset, w.visitor.opts
}
+
+// MarkReadOnlyShared marks the pagetables as read-only so that they can be shared.
+//
+// It is usually used on the pagetables that are used as the upper portion.
+func (p *PageTables) MarkReadOnlyShared() {
+ p.readOnlyShared = true
+}
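Map and Unmap now refuse to touch the shared upper half: a request entirely at or above upperStart is ignored, and a request that straddles upperStart is truncated to end there. A standalone sketch of just that clamp, assuming page-aligned inputs as the preconditions require (the helper name is illustrative):

package main

import "fmt"

// clampToLowerHalf mirrors the bounds check added to Map and Unmap: requests
// entirely in the read-only shared upper half are dropped, and requests that
// straddle upperStart are truncated to end at it.
func clampToLowerHalf(addr, length, upperStart uintptr) (uintptr, bool) {
	if addr+length < addr {
		panic("addr & length overflow")
	}
	if addr >= upperStart {
		return 0, false // Entirely in the shared upper half; nothing to do.
	}
	if addr+length > upperStart {
		length = upperStart - addr
	}
	return length, true
}

func main() {
	const upperStart = 0x1000000
	fmt.Println(clampToLowerHalf(0x2000000, 0x1000, upperStart)) // 0 false
	fmt.Println(clampToLowerHalf(0xfff000, 0x2000, upperStart))  // 4096 true
}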
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
index 520161755..a4e416af7 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -24,14 +24,6 @@ import (
// archPageTables is architecture-specific data.
type archPageTables struct {
- // root is the pagetable root for kernel space.
- root *PTEs
-
- // rootPhysical is the cached physical address of the root.
- //
- // This is saved only to prevent constant translation.
- rootPhysical uintptr
-
asid uint16
}
@@ -46,7 +38,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 {
//
//go:nosplit
func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {
- return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+ return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
}
// Bits in page table entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
index 0c153cf8c..e7ab887e5 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -50,5 +50,26 @@ func (p *PageTables) Init(allocator Allocator) {
p.rootPhysical = p.Allocator.PhysicalFor(p.root)
}
+func pgdIndex(upperStart uintptr) uintptr {
+ if upperStart&(pgdSize-1) != 0 {
+ panic("upperStart should be pgd size aligned")
+ }
+ if upperStart >= upperBottom {
+ return entriesPerPage/2 + (upperStart-upperBottom)/pgdSize
+ }
+ if upperStart < lowerTop {
+ return upperStart / pgdSize
+ }
+ panic("upperStart should be in canonical range")
+}
+
+// cloneUpperShared clones the upper portion from the upper shared page tables.
+//
+//go:nosplit
+func (p *PageTables) cloneUpperShared() {
+ start := pgdIndex(p.upperStart)
+ copy(p.root[start:entriesPerPage], p.upperSharedPageTables.root[start:entriesPerPage])
+}
+
// PTEs is a collection of entries.
type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
index 5ddd10256..5392bf27a 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -49,8 +49,17 @@ func (p *PageTables) Init(allocator Allocator) {
p.Allocator = allocator
p.root = p.Allocator.NewPTEs()
p.rootPhysical = p.Allocator.PhysicalFor(p.root)
- p.archPageTables.root = p.Allocator.NewPTEs()
- p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root)
+}
+
+// cloneUpperShared clones the upper portion from the upper shared page tables.
+//
+//go:nosplit
+func (p *PageTables) cloneUpperShared() {
+ if p.upperStart != upperBottom {
+ panic("upperStart should be the same as upperBottom")
+ }
+
+ // Nothing to do for arm64.
}
// PTEs is a collection of entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
index c261d393a..157c9a7cc 100644
--- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
@@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr {
func (w *Walker) iterateRangeCanonical(start, end uintptr) {
pgdEntryIndex := w.pageTables.root
if start >= upperBottom {
- pgdEntryIndex = w.pageTables.archPageTables.root
+ pgdEntryIndex = w.pageTables.upperSharedPageTables.root
}
for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 549787955..e0976fed0 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -100,24 +100,43 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf
// marshalTarget and unmarshalTarget can be used.
type targetMaker interface {
// id uniquely identifies the target.
- id() stack.TargetID
+ id() targetID
- // marshal converts from a stack.Target to an ABI struct.
- marshal(target stack.Target) []byte
+ // marshal converts from a target to an ABI struct.
+ marshal(target target) []byte
- // unmarshal converts from the ABI matcher struct to a stack.Target.
- unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error)
+ // unmarshal converts from the ABI matcher struct to a target.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error)
}
-// targetMakers maps the TargetID of supported targets to the targetMaker that
+// A targetID uniquely identifies a target.
+type targetID struct {
+ // name is the target name as stored in the xt_entry_target struct.
+ name string
+
+ // networkProtocol is the protocol to which the target applies.
+ networkProtocol tcpip.NetworkProtocolNumber
+
+ // revision is the version of the target.
+ revision uint8
+}
+
+// target extends a stack.Target, allowing it to be used with the extension
+// system. The sentry only uses targets, never stack.Targets directly.
+type target interface {
+ stack.Target
+ id() targetID
+}
+
+// targetMakers maps the targetID of supported targets to the targetMaker that
// marshals and unmarshals it. It is immutable after package initialization.
-var targetMakers = map[stack.TargetID]targetMaker{}
+var targetMakers = map[targetID]targetMaker{}
func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) {
- tid := stack.TargetID{
- Name: name,
- NetworkProtocol: netProto,
- Revision: rev,
+ tid := targetID{
+ name: name,
+ networkProtocol: netProto,
+ revision: rev,
}
if _, ok := targetMakers[tid]; !ok {
return 0, false
@@ -126,8 +145,8 @@ func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8
// Return the highest supported revision unless rev is higher.
for _, other := range targetMakers {
otherID := other.id()
- if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev {
- rev = uint8(otherID.Revision)
+ if name == otherID.name && netProto == otherID.networkProtocol && otherID.revision > rev {
+ rev = uint8(otherID.revision)
}
}
return rev, true
@@ -142,19 +161,21 @@ func registerTargetMaker(tm targetMaker) {
targetMakers[tm.id()] = tm
}
-func marshalTarget(target stack.Target) []byte {
- targetMaker, ok := targetMakers[target.ID()]
+func marshalTarget(tgt stack.Target) []byte {
+ // The sentry only uses targets, never stack.Targets directly.
+ target := tgt.(target)
+ targetMaker, ok := targetMakers[target.id()]
if !ok {
- panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID()))
+ panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.id()))
}
return targetMaker.marshal(target)
}
-func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) {
- tid := stack.TargetID{
- Name: target.Name.String(),
- NetworkProtocol: filter.NetworkProtocol(),
- Revision: target.Revision,
+func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (target, *syserr.Error) {
+ tid := targetID{
+ name: target.Name.String(),
+ networkProtocol: filter.NetworkProtocol(),
+ revision: target.Revision,
}
targetMaker, ok := targetMakers[tid]
if !ok {
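Targets are now keyed by the unexported targetID, and every sentry-side target must satisfy the small target interface (a stack.Target plus id()), which the wrapper types in targets.go provide. A self-contained sketch of the wrapper pattern with stand-in types in place of the real stack package:

package main

import "fmt"

// stackTarget stands in for stack.Target.
type stackTarget interface{ Action() string }

// acceptAction stands in for stack.AcceptTarget.
type acceptAction struct{ NetworkProtocol int }

func (acceptAction) Action() string { return "accept" }

// targetID mirrors the key used by targetMakers: name, protocol, revision.
type targetID struct {
	name            string
	networkProtocol int
	revision        uint8
}

// target mirrors the sentry-side interface: a stack target plus an id.
type target interface {
	stackTarget
	id() targetID
}

// acceptTarget wraps the stack-provided action for the extension system, as
// the wrappers in targets.go do.
type acceptTarget struct{ acceptAction }

func (at acceptTarget) id() targetID {
	return targetID{networkProtocol: at.NetworkProtocol}
}

func main() {
	var t target = acceptTarget{acceptAction{NetworkProtocol: 4}}
	fmt.Printf("%s %+v\n", t.Action(), t.id()) // accept {name: networkProtocol:4 revision:0}
}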
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index b560fae0d..70c561cce 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -46,13 +46,13 @@ func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linu
return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- table, ok := stk.IPTables().GetTable(tablename.String(), false)
+ id, ok := nameToID[tablename.String()]
if !ok {
return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
}
// Setup the info struct.
- entries, info := getEntries4(table, tablename)
+ entries, info := getEntries4(stk.IPTables().GetTable(id, false), tablename)
return entries, info, nil
}
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 4253f7bf4..5dbb604f0 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -46,13 +46,13 @@ func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linu
return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- table, ok := stk.IPTables().GetTable(tablename.String(), true)
+ id, ok := nameToID[tablename.String()]
if !ok {
return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
}
// Setup the info struct, which is the same in IPv4 and IPv6.
- entries, info := getEntries6(table, tablename)
+ entries, info := getEntries6(stk.IPTables().GetTable(id, true), tablename)
return entries, info, nil
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 904a12e38..b283d7229 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -42,6 +42,45 @@ func nflog(format string, args ...interface{}) {
}
}
+// Table names.
+const (
+ natTable = "nat"
+ mangleTable = "mangle"
+ filterTable = "filter"
+)
+
+// nameToID is immutable.
+var nameToID = map[string]stack.TableID{
+ natTable: stack.NATID,
+ mangleTable: stack.MangleID,
+ filterTable: stack.FilterID,
+}
+
+// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for
+// compatibility with netfilter extensions.
+func DefaultLinuxTables() *stack.IPTables {
+ tables := stack.DefaultTables()
+ tables.VisitTargets(func(oldTarget stack.Target) stack.Target {
+ switch val := oldTarget.(type) {
+ case *stack.AcceptTarget:
+ return &acceptTarget{AcceptTarget: *val}
+ case *stack.DropTarget:
+ return &dropTarget{DropTarget: *val}
+ case *stack.ErrorTarget:
+ return &errorTarget{ErrorTarget: *val}
+ case *stack.UserChainTarget:
+ return &userChainTarget{UserChainTarget: *val}
+ case *stack.ReturnTarget:
+ return &returnTarget{ReturnTarget: *val}
+ case *stack.RedirectTarget:
+ return &redirectTarget{RedirectTarget: *val}
+ default:
+ panic(fmt.Sprintf("Unknown rule in default iptables of type %T", val))
+ }
+ })
+ return tables
+}
+
// GetInfo returns information about iptables.
func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
@@ -144,9 +183,9 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
- case stack.FilterTable:
+ case filterTable:
table = stack.EmptyFilterTable()
- case stack.NATTable:
+ case natTable:
table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
@@ -177,7 +216,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
if offset == replace.Underflow[hook] {
if !validUnderflow(table.Rules[ruleIdx], ipv6) {
- nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP: %+v", ruleIdx)
return syserr.ErrInvalidArgument
}
table.Underflows[hk] = ruleIdx
@@ -253,8 +292,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6))
-
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(nameToID[replace.Name.String()], table, ipv6))
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
@@ -308,7 +346,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool {
return false
}
switch rule.Target.(type) {
- case *stack.AcceptTarget, *stack.DropTarget:
+ case *acceptTarget, *dropTarget:
return true
default:
return false
@@ -319,7 +357,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool {
if !validUnderflow(rule, ipv6) {
return false
}
- _, ok := rule.Target.(*stack.AcceptTarget)
+ _, ok := rule.Target.(*acceptTarget)
return ok
}
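Table lookups now go through stack.TableID rather than strings, with nameToID doing the translation at the sentry boundary before GetTable or ReplaceTable is called. A minimal sketch of that translation step, with made-up ID constants standing in for stack.NATID, stack.MangleID, and stack.FilterID:

package main

import "fmt"

// tableID stands in for stack.TableID.
type tableID int

const (
	natID tableID = iota
	mangleID
	filterID
)

// nameToID mirrors the sentry-side translation from iptables table names to
// netstack table IDs.
var nameToID = map[string]tableID{
	"nat":    natID,
	"mangle": mangleID,
	"filter": filterID,
}

func lookupTable(name string) (tableID, error) {
	id, ok := nameToID[name]
	if !ok {
		return 0, fmt.Errorf("couldn't find table %q", name)
	}
	return id, nil
}

func main() {
	fmt.Println(lookupTable("filter")) // 2 <nil>
	fmt.Println(lookupTable("raw"))    // 0 couldn't find table "raw"
}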
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 0e14447fe..f2653d523 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -26,6 +26,15 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// ErrorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const ErrorTargetName = "ERROR"
+
+// RedirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port and/or IP for packets.
+const RedirectTargetName = "REDIRECT"
+
func init() {
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
registerTargetMaker(&standardTargetMaker{
@@ -52,25 +61,96 @@ func init() {
})
}
+// The stack package provides some basic, useful targets for us. The following
+// types wrap them for compatibility with the extension system.
+
+type acceptTarget struct {
+ stack.AcceptTarget
+}
+
+func (at *acceptTarget) id() targetID {
+ return targetID{
+ networkProtocol: at.NetworkProtocol,
+ }
+}
+
+type dropTarget struct {
+ stack.DropTarget
+}
+
+func (dt *dropTarget) id() targetID {
+ return targetID{
+ networkProtocol: dt.NetworkProtocol,
+ }
+}
+
+type errorTarget struct {
+ stack.ErrorTarget
+}
+
+func (et *errorTarget) id() targetID {
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: et.NetworkProtocol,
+ }
+}
+
+type userChainTarget struct {
+ stack.UserChainTarget
+}
+
+func (uc *userChainTarget) id() targetID {
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: uc.NetworkProtocol,
+ }
+}
+
+type returnTarget struct {
+ stack.ReturnTarget
+}
+
+func (rt *returnTarget) id() targetID {
+ return targetID{
+ networkProtocol: rt.NetworkProtocol,
+ }
+}
+
+type redirectTarget struct {
+ stack.RedirectTarget
+
+ // addr must be (un)marshalled when reading and writing the target to
+ // userspace, but does not affect behavior.
+ addr tcpip.Address
+}
+
+func (rt *redirectTarget) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rt.NetworkProtocol,
+ }
+}
+
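The wrappers above exist so every netfilter-visible target reports a package-local targetID, which the maker registry keys on when marshalling rules back to userspace. A rough sketch of that dispatch, using a hypothetical map and helper (the real registry is maintained by registerTargetMaker elsewhere in this package, and the target/targetMaker interfaces are assumed from the surrounding code):

// makersByID is a hypothetical registry keyed by targetID; names and shapes
// here are illustrative only.
var makersByID = map[targetID]targetMaker{}

// marshalByID shows the intended flow: ask the target for its id(), find the
// maker registered under that id, and let it produce the userspace bytes.
func marshalByID(tgt target) ([]byte, *syserr.Error) {
	maker, ok := makersByID[tgt.id()]
	if !ok {
		return nil, syserr.ErrInvalidArgument
	}
	return maker.marshal(tgt), nil
}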
type standardTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (sm *standardTargetMaker) id() stack.TargetID {
+func (sm *standardTargetMaker) id() targetID {
// Standard targets have the empty string as a name and no revisions.
- return stack.TargetID{
- NetworkProtocol: sm.NetworkProtocol,
+ return targetID{
+ networkProtocol: sm.NetworkProtocol,
}
}
-func (*standardTargetMaker) marshal(target stack.Target) []byte {
+
+func (*standardTargetMaker) marshal(target target) []byte {
// Translate verdicts the same way as the iptables tool.
var verdict int32
switch tg := target.(type) {
- case *stack.AcceptTarget:
+ case *acceptTarget:
verdict = -linux.NF_ACCEPT - 1
- case *stack.DropTarget:
+ case *dropTarget:
verdict = -linux.NF_DROP - 1
- case *stack.ReturnTarget:
+ case *returnTarget:
verdict = linux.NF_RETURN
case *JumpTarget:
verdict = int32(tg.Offset)
@@ -90,7 +170,7 @@ func (*standardTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) != linux.SizeOfXTStandardTarget {
nflog("buf has wrong size for standard target %d", len(buf))
return nil, syserr.ErrInvalidArgument
@@ -114,20 +194,20 @@ type errorTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (em *errorTargetMaker) id() stack.TargetID {
+func (em *errorTargetMaker) id() targetID {
// Error targets have no revision.
- return stack.TargetID{
- Name: stack.ErrorTargetName,
- NetworkProtocol: em.NetworkProtocol,
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: em.NetworkProtocol,
}
}
-func (*errorTargetMaker) marshal(target stack.Target) []byte {
+func (*errorTargetMaker) marshal(target target) []byte {
var errorName string
switch tg := target.(type) {
- case *stack.ErrorTarget:
- errorName = stack.ErrorTargetName
- case *stack.UserChainTarget:
+ case *errorTarget:
+ errorName = ErrorTargetName
+ case *userChainTarget:
errorName = tg.Name
default:
panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target))
@@ -140,37 +220,38 @@ func (*errorTargetMaker) marshal(target stack.Target) []byte {
},
}
copy(xt.Name[:], errorName)
- copy(xt.Target.Name[:], stack.ErrorTargetName)
+ copy(xt.Target.Name[:], ErrorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) != linux.SizeOfXTErrorTarget {
nflog("buf has insufficient size for error target %d", len(buf))
return nil, syserr.ErrInvalidArgument
}
- var errorTarget linux.XTErrorTarget
+ var errTgt linux.XTErrorTarget
buf = buf[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+ binary.Unmarshal(buf, usermem.ByteOrder, &errTgt)
// Error targets are used in 2 cases:
- // * An actual error case. These rules have an error
- // named stack.ErrorTargetName. The last entry of the table
- // is usually an error case to catch any packets that
- // somehow fall through every rule.
+ // * An actual error case. These rules have an error named
+ // ErrorTargetName. The last entry of the table is usually an error
+ // case to catch any packets that somehow fall through every rule.
// * To mark the start of a user defined chain. These
// rules have an error with the name of the chain.
- switch name := errorTarget.Name.String(); name {
- case stack.ErrorTargetName:
- return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil
+ switch name := errTgt.Name.String(); name {
+ case ErrorTargetName:
+ return &errorTarget{stack.ErrorTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}, nil
default:
// User defined chain.
- return &stack.UserChainTarget{
+ return &userChainTarget{stack.UserChainTarget{
Name: name,
NetworkProtocol: filter.NetworkProtocol(),
- }, nil
+ }}, nil
}
}
@@ -178,22 +259,22 @@ type redirectTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (rm *redirectTargetMaker) id() stack.TargetID {
- return stack.TargetID{
- Name: stack.RedirectTargetName,
- NetworkProtocol: rm.NetworkProtocol,
+func (rm *redirectTargetMaker) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rm.NetworkProtocol,
}
}
-func (*redirectTargetMaker) marshal(target stack.Target) []byte {
- rt := target.(*stack.RedirectTarget)
+func (*redirectTargetMaker) marshal(target target) []byte {
+ rt := target.(*redirectTarget)
// This is a redirect target named REDIRECT.
xt := linux.XTRedirectTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTRedirectTarget,
},
}
- copy(xt.Target.Name[:], stack.RedirectTargetName)
+ copy(xt.Target.Name[:], RedirectTargetName)
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
xt.NfRange.RangeSize = 1
@@ -203,7 +284,7 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) < linux.SizeOfXTRedirectTarget {
nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf))
return nil, syserr.ErrInvalidArgument
@@ -214,15 +295,17 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- var redirectTarget linux.XTRedirectTarget
+ var rt linux.XTRedirectTarget
buf = buf[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+ binary.Unmarshal(buf, usermem.ByteOrder, &rt)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
- target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()}
+ target := redirectTarget{RedirectTarget: stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}
// RangeSize should be 1.
- nfRange := redirectTarget.NfRange
+ nfRange := rt.NfRange
if nfRange.RangeSize != 1 {
nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize)
return nil, syserr.ErrInvalidArgument
@@ -247,7 +330,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
target.Port = ntohs(nfRange.RangeIPV4.MinPort)
return &target, nil
@@ -264,15 +347,15 @@ type nfNATTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (rm *nfNATTargetMaker) id() stack.TargetID {
- return stack.TargetID{
- Name: stack.RedirectTargetName,
- NetworkProtocol: rm.NetworkProtocol,
+func (rm *nfNATTargetMaker) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rm.NetworkProtocol,
}
}
-func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
- rt := target.(*stack.RedirectTarget)
+func (*nfNATTargetMaker) marshal(target target) []byte {
+ rt := target.(*redirectTarget)
nt := nfNATTarget{
Target: linux.XTEntryTarget{
TargetSize: nfNATMarhsalledSize,
@@ -281,9 +364,9 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
},
}
- copy(nt.Target.Name[:], stack.RedirectTargetName)
- copy(nt.Range.MinAddr[:], rt.Addr)
- copy(nt.Range.MaxAddr[:], rt.Addr)
+ copy(nt.Target.Name[:], RedirectTargetName)
+ copy(nt.Range.MinAddr[:], rt.addr)
+ copy(nt.Range.MaxAddr[:], rt.addr)
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
@@ -292,7 +375,7 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, nt)
}
-func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if size := nfNATMarhsalledSize; len(buf) < size {
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
return nil, syserr.ErrInvalidArgument
@@ -324,10 +407,12 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta
return nil, syserr.ErrInvalidArgument
}
- target := stack.RedirectTarget{
- NetworkProtocol: filter.NetworkProtocol(),
- Addr: tcpip.Address(natRange.MinAddr[:]),
- Port: ntohs(natRange.MinProto),
+ target := redirectTarget{
+ RedirectTarget: stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Port: ntohs(natRange.MinProto),
+ },
+ addr: tcpip.Address(natRange.MinAddr[:]),
}
return &target, nil
@@ -335,18 +420,24 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to a stack.Verdict.
-func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) {
+func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return &stack.AcceptTarget{NetworkProtocol: netProto}, nil
+ return &acceptTarget{stack.AcceptTarget{
+ NetworkProtocol: netProto,
+ }}, nil
case -linux.NF_DROP - 1:
- return &stack.DropTarget{NetworkProtocol: netProto}, nil
+ return &dropTarget{stack.DropTarget{
+ NetworkProtocol: netProto,
+ }}, nil
case -linux.NF_QUEUE - 1:
nflog("unsupported iptables verdict QUEUE")
return nil, syserr.ErrInvalidArgument
case linux.NF_RETURN:
- return &stack.ReturnTarget{NetworkProtocol: netProto}, nil
+ return &returnTarget{stack.ReturnTarget{
+ NetworkProtocol: netProto,
+ }}, nil
default:
nflog("unknown iptables verdict %d", val)
return nil, syserr.ErrInvalidArgument
@@ -382,9 +473,9 @@ type JumpTarget struct {
}
// ID implements Target.ID.
-func (jt *JumpTarget) ID() stack.TargetID {
- return stack.TargetID{
- NetworkProtocol: jt.NetworkProtocol,
+func (jt *JumpTarget) id() targetID {
+ return targetID{
+ networkProtocol: jt.NetworkProtocol,
}
}
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index 22216158e..f4d034c13 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -487,7 +487,7 @@ func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlin
Addr: value,
})
if err != nil {
- return syserr.ErrInvalidArgument
+ return syserr.ErrBadLocalAddress
}
case linux.IFA_ADDRESS:
default:
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index aa4f3c04d..6d9e502bd 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -142,9 +142,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
}
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
- q1.EnableLeakCheck()
+ q1.InitRefs()
q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
- q2.EnableLeakCheck()
+ q2.InitRefs()
if stype == linux.SOCK_STREAM {
a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
@@ -300,14 +300,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
}
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
- readQueue.EnableLeakCheck()
+ readQueue.InitRefs()
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
- writeQueue.EnableLeakCheck()
+ writeQueue.InitRefs()
if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index f8aacca13..1406971bc 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -42,7 +42,7 @@ var (
func NewConnectionless(ctx context.Context) Endpoint {
ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
- q.EnableLeakCheck()
+ q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
return ep
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index adad485a9..b32bb7ba8 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -80,7 +80,7 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
stype: stype,
},
}
- s.EnableLeakCheck()
+ s.InitRefs()
return fs.NewFile(ctx, d, flags, &s)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 7a78444dc..eaf0b0d26 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -80,7 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
stype: stype,
},
}
- sock.EnableLeakCheck()
+ sock.InitRefs()
sock.LockFD.Init(locks)
vfsfd := &sock.vfsfd
if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 36902d177..bb1f715e2 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{
63: syscalls.Supported("uname", Uname),
64: syscalls.Supported("semget", Semget),
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
@@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{
188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 849a47476..f7135ea46 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -32,7 +32,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
return 0, syserror.EINVAL
}
- r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize)
+ r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize)
r.SetFlags(linuxToFlags(flags).Settable())
defer r.DecRef(t)
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index c2d4bf805..e383a0a87 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -138,12 +138,18 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
+ case linux.GETZCNT:
+ v, err := getZCnt(t, id, num)
+ return uintptr(v), nil, err
+
+ case linux.GETNCNT:
+ v, err := getNCnt(t, id, num)
+ return uintptr(v), nil, err
+
case linux.IPC_INFO,
linux.SEM_INFO,
linux.SEM_STAT,
- linux.SEM_STAT_ANY,
- linux.GETNCNT,
- linux.GETZCNT:
+ linux.SEM_STAT_ANY:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
@@ -258,3 +264,23 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
}
return int32(tg.ID()), nil
}
+
+func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.CountZeroWaiters(num, creds)
+}
+
+func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.CountNegativeWaiters(num, creds)
+}
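With GETZCNT and GETNCNT now handled, a semctl(2) caller can query how many tasks are blocked on a semaphore. A hypothetical userspace sketch (linux/amd64, raw syscall; the command values 14 and 15 correspond to Linux's GETNCNT and GETZCNT, and semid stands in for an ID returned by a prior semget(2)):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	const (
		getNCnt = 14 // GETNCNT: tasks waiting for the value to increase.
		getZCnt = 15 // GETZCNT: tasks waiting for the value to become zero.
	)
	semid := uintptr(0) // assumed: obtained from a prior semget(2)
	ncnt, _, errno := unix.Syscall6(unix.SYS_SEMCTL, semid, 0 /* semnum */, getNCnt, 0, 0, 0)
	if errno != 0 {
		fmt.Println("semctl GETNCNT failed:", errno)
		return
	}
	zcnt, _, errno := unix.Syscall6(unix.SYS_SEMCTL, semid, 0 /* semnum */, getZCnt, 0, 0, 0)
	if errno != 0 {
		fmt.Println("semctl GETZCNT failed:", errno)
		return
	}
	fmt.Printf("waiting for increment: %d, waiting for zero: %d\n", ncnt, zcnt)
}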
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 46616c961..1c4cdb0dd 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -41,6 +41,7 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
inCh chan struct{}
outCh chan struct{}
)
+
for opts.Length > 0 {
n, err = fs.Splice(t, outFile, inFile, opts)
opts.Length -= n
@@ -61,23 +62,28 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
inW, _ := waiter.NewChannelEntry(inCh)
inFile.EventRegister(&inW, EventMaskRead)
defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
+ // Need to refresh readiness.
+ continue
}
if err = t.Block(inCh); err != nil {
break
}
}
- if outFile.Readiness(EventMaskWrite) == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, EventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(outCh); err != nil {
- break
- }
+ // Don't bother checking readiness of the outFile, because readiness is
+ // no guarantee that a write won't return EWOULDBLOCK. Both pipes and
+ // eventfds can be "ready" but will reject writes of certain sizes with
+ // EWOULDBLOCK.
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, EventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ // We might be ready to write now. Try again before
+ // blocking.
+ continue
+ }
+ if err = t.Block(outCh); err != nil {
+ break
}
}
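The comment above describes a general pattern: never trust writer readiness, register a waiter entry once, and retry the operation before blocking. A minimal sketch of that loop in isolation, assuming the kernel, fs, waiter, and syserror packages already imported by this file and a caller-supplied tryWrite function (both hypothetical framing, not code from the change):

// writeRetry illustrates the register-then-retry pattern used above: the
// waiter entry is registered at most once, and the write is retried before
// the task blocks, since the file may have become writable in the meantime.
func writeRetry(t *kernel.Task, outFile *fs.File, tryWrite func() error) error {
	var (
		outCh chan struct{}
		outW  waiter.Entry
	)
	for {
		err := tryWrite()
		if err != syserror.ErrWouldBlock {
			return err
		}
		if outCh == nil {
			outCh = make(chan struct{}, 1)
			outW, _ = waiter.NewChannelEntry(outCh)
			outFile.EventRegister(&outW, waiter.EventOut)
			defer outFile.EventUnregister(&outW)
			// Retry before blocking.
			continue
		}
		if err := t.Block(outCh); err != nil {
			return err
		}
	}
}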
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 035e2a6b0..9ce4f280a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -480,18 +480,17 @@ func (dw *dualWaiter) waitForBoth(t *kernel.Task) error {
// waitForOut waits for dw.outfile to be read.
func (dw *dualWaiter) waitForOut(t *kernel.Task) error {
- if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if dw.outCh == nil {
- dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
- dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
- // We might be ready now. Try again before blocking.
- return nil
- }
- if err := t.Block(dw.outCh); err != nil {
- return err
- }
- }
- return nil
+ // Don't bother checking readiness of the outFile, because readiness is
+ // no guarantee that a write won't return EWOULDBLOCK. Both pipes and
+ // eventfds can be "ready" but will reject writes of certain sizes with
+ // EWOULDBLOCK. See b/172075629, b/170743336.
+ if dw.outCh == nil {
+ dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
+ dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
+ // We might be ready to write now. Try again before blocking.
+ return nil
+ }
+ return t.Block(dw.outCh)
}
// destroy cleans up resources held by dw. No more calls to wait* can occur
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 546e445aa..936f9fc71 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -133,7 +133,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou
}
}
- fd.EnableLeakCheck()
+ fd.InitRefs()
// Remove "file creation flags" to mirror the behavior from file.f_flags in
// fs/open.c:do_dentry_open.
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index c93d94634..2c4b81e78 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -48,7 +48,7 @@ type Filesystem struct {
// Init must be called before first use of fs.
func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) {
- fs.EnableLeakCheck()
+ fs.InitRefs()
fs.vfs = vfsObj
fs.fsType = fsType
fs.impl = impl
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index 3f0b8f45b..107171b61 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -65,7 +65,7 @@ type Inotify struct {
// queue is used to notify interested parties when the inotify instance
// becomes readable or writable.
- queue waiter.Queue `state:"nosave"`
+ queue waiter.Queue
// evMu *only* protects the events list. We need a separate lock while
// queuing events: using mu may violate lock ordering, since at that point
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 3ea981ad4..d865fd603 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -169,7 +169,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
Owner: creds.UserNamespace,
mountpoints: make(map[*Dentry]uint32),
}
- mntns.EnableLeakCheck()
+ mntns.InitRefs()
mntns.root = newMount(vfs, fs, root, mntns, opts)
return mntns, nil
}
@@ -477,7 +477,9 @@ func (mnt *Mount) tryIncMountedRef() bool {
return false
}
if atomic.CompareAndSwapInt64(&mnt.refs, r, r+1) {
- refsvfs2.LogTryIncRef(mnt, r+1)
+ if mnt.LogRefs() {
+ refsvfs2.LogTryIncRef(mnt, r+1)
+ }
return true
}
}
@@ -488,12 +490,17 @@ func (mnt *Mount) IncRef() {
// In general, negative values for mnt.refs are valid because the MSB is
// the eager-unmount bit.
r := atomic.AddInt64(&mnt.refs, 1)
- refsvfs2.LogIncRef(mnt, r)
+ if mnt.LogRefs() {
+ refsvfs2.LogIncRef(mnt, r)
+ }
}
// DecRef decrements mnt's reference count.
func (mnt *Mount) DecRef(ctx context.Context) {
r := atomic.AddInt64(&mnt.refs, -1)
+ if mnt.LogRefs() {
+ refsvfs2.LogDecRef(mnt, r)
+ }
if r&^math.MinInt64 == 0 { // mask out MSB
refsvfs2.Unregister(mnt)
mnt.destroy(ctx)
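The masking in DecRef relies on the convention that the MSB of mnt.refs is the eager-unmount flag while the low 63 bits carry the reference count. A tiny sketch of that check in isolation (assumes the math package; illustrative only, not code from the change):

// isLastRef mirrors the check in Mount.DecRef: strip the eager-unmount bit
// (the sign bit of the int64) and test whether any real references remain.
func isLastRef(refs int64) bool {
	return refs&^math.MinInt64 == 0
}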
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index cb8c56bd3..cb882a983 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -29,7 +29,7 @@ func TestMountTableLookupEmpty(t *testing.T) {
parent := &Mount{}
point := &Dentry{}
if m := mt.Lookup(parent, point); m != nil {
- t.Errorf("empty mountTable lookup: got %p, wanted nil", m)
+ t.Errorf("Empty mountTable lookup: got %p, wanted nil", m)
}
}
@@ -111,13 +111,16 @@ func BenchmarkMountTableParallelLookup(b *testing.B) {
k := keys[i&(numMounts-1)]
m := mt.Lookup(k.mount, k.dentry)
if m == nil {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -167,13 +170,16 @@ func BenchmarkMountMapParallelLookup(b *testing.B) {
m := ms[k]
mu.RUnlock()
if m == nil {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -220,14 +226,17 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) {
k := keys[i&(numMounts-1)]
mi, ok := ms.Load(k)
if !ok {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
m := mi.(*Mount)
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -264,7 +273,7 @@ func BenchmarkMountTableNegativeLookup(b *testing.B) {
k := negkeys[i&(numMounts-1)]
m := mt.Lookup(k.mount, k.dentry)
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
@@ -300,7 +309,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) {
m := ms[k]
mu.RUnlock()
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
@@ -333,7 +342,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) {
k := negkeys[i&(numMounts-1)]
m, _ := ms.Load(k)
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD
index f08599ebd..cb0001852 100644
--- a/pkg/shim/runsc/BUILD
+++ b/pkg/shim/runsc/BUILD
@@ -10,6 +10,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "@com_github_containerd_containerd//log:go_default_library",
"@com_github_containerd_go_runc//:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go
index c5cf68efa..e7c9640ba 100644
--- a/pkg/shim/runsc/runsc.go
+++ b/pkg/shim/runsc/runsc.go
@@ -28,10 +28,12 @@ import (
"syscall"
"time"
+ "github.com/containerd/containerd/log"
runc "github.com/containerd/go-runc"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
+// Monitor is the default process monitor to be used by runsc.
var Monitor runc.ProcessMonitor = runc.Monitor
// DefaultCommand is the default command for Runsc.
@@ -74,6 +76,7 @@ func (r *Runsc) State(context context.Context, id string) (*runc.Container, erro
return &c, nil
}
+// CreateOpts is a set of options to Runsc.Create().
type CreateOpts struct {
runc.IO
ConsoleSocket runc.ConsoleSocket
@@ -197,6 +200,7 @@ func (r *Runsc) Wait(context context.Context, id string) (int, error) {
return res.ExitStatus, nil
}
+// ExecOpts is a set of options to runsc.Exec().
type ExecOpts struct {
runc.IO
PidFile string
@@ -301,6 +305,7 @@ func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts
return Monitor.Wait(cmd, ec)
}
+// DeleteOpts is a set of options to runsc.Delete().
type DeleteOpts struct {
Force bool
}
@@ -367,6 +372,13 @@ func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) {
if err := json.NewDecoder(rd).Decode(&e); err != nil {
return nil, err
}
+ log.L.Debugf("Stats returned: %+v", e.Stats)
+ if e.Type != "stats" {
+ return nil, fmt.Errorf(`unexpected event type %q, wanted "stats"`, e.Type)
+ }
+ if e.Stats == nil {
+ return nil, fmt.Errorf(`"runsc events -stat" succeeded but no stat was provided`)
+ }
return e.Stats, nil
}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 09cb153b1..4e7e5f76a 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -375,6 +375,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback
+// address.
+func IsV6LoopbackAddress(addr tcpip.Address) bool {
+ return addr == IPv6Loopback
+}
+
// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
// link-local multicast address.
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 560477926..b3e8c4b92 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -205,7 +205,12 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
//
// We don't clone the original packet buffer so that the new packet buffer
// does not have any of its headers set.
- pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())})
+ //
+ // We trim the link headers from the cloned buffer as the sniffer doesn't
+ // handle link headers.
+ vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ vv.TrimFront(len(pkt.LinkHeader().View()))
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
switch protocol {
case header.IPv4ProtocolNumber:
if ok := parse.IPv4(pkt); !ok {
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index cda6328a2..4c14f55d3 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -157,7 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
name: name,
isTap: prefix == "tap",
}
- endpoint.EnableLeakCheck()
+ endpoint.InitRefs()
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
endpoint.name = fmt.Sprintf("%s%d", prefix, id)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index a79379abb..33a4a0720 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -122,7 +122,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
@@ -145,7 +145,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
@@ -158,6 +158,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
+ respPkt.NetworkProtocolNumber = ProtocolNumber
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
// TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 969579601..8873bd91f 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -110,8 +110,9 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
- t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
+func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
+ netHdr := pkt.Network()
+ t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress())
t.dataCalls++
return stack.TransportPacketHandled
}
@@ -608,7 +609,8 @@ func TestIPv4Receive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -707,7 +709,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
nic.testObject.typ = c.expectedTyp
nic.testObject.extra = c.expectedExtra
- ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
+ pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
@@ -788,7 +792,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
@@ -800,7 +805,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -900,7 +906,8 @@ func TestIPv6Receive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -1017,7 +1024,9 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
+ pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
@@ -1071,7 +1080,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum tcpip.NetworkProtocolNumber
nicAddr tcpip.Address
remoteAddr tcpip.Address
- pktGen func(*testing.T, tcpip.Address) buffer.View
+ pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
expectedErr *tcpip.Error
}{
@@ -1081,7 +1090,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1095,7 +1104,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1123,7 +1132,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1137,7 +1146,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1147,7 +1156,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -1156,7 +1165,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip[:len(ip)-1])
+ return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1166,7 +1175,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -1175,7 +1184,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip)
+ return buffer.View(ip).ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1203,7 +1212,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := header.IPv4MinimumSize + len(ipv4Options)
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
@@ -1221,7 +1230,49 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) {
t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options))
}
- return hdr.View()
+ return hdr.View().ToVectorisedView()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ if len(netHdr.View()) != hdrLen {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(hdrLen),
+ checker.IPFullLength(uint16(hdrLen+len(data))),
+ checker.IPv4Options(ipv4Options),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv4 with options and data across views",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(header.IPv4MinimumSize + len(ipv4Options)),
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ vv := buffer.View(ip).ToVectorisedView()
+ vv.AppendView(ipv4Options)
+ vv.AppendView(data)
+ return vv
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1251,7 +1302,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1264,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1291,7 +1342,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1307,7 +1358,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1334,7 +1385,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
NextHeader: transportProto,
@@ -1342,7 +1393,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip)
+ return buffer.View(ip).ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1369,7 +1420,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
NextHeader: transportProto,
@@ -1377,7 +1428,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip[:len(ip)-1])
+ return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1421,7 +1472,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
defer r.Release()
if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(),
+ Data: test.pktGen(t, subTest.srcAddr),
})); err != test.expectedErr {
t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index cf287446e..9b5e37fee 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -42,8 +42,8 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
//
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
- src := hdr.SourceAddress()
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
+ srcAddr := hdr.SourceAddress()
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 {
return
}
@@ -58,11 +58,11 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
- stats := r.Stats()
+func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
received := stats.ICMP.V4PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
@@ -83,7 +83,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// packets with checksum errors.
switch h.Type() {
case header.ICMPv4Echo:
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
}
return
}
@@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
} else {
op = &optionUsageReceive{}
}
- aux, tmp, err := processIPOptions(r, iph.Options(), op)
+ aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op)
if err != nil {
switch {
case
@@ -116,9 +116,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
errors.Is(err, errIPv4TimestampOptInvalidLength),
errors.Is(err, errIPv4TimestampOptInvalidPointer),
errors.Is(err, errIPv4TimestampOptOverflow):
- _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
+ stats.MalformedRcvdPackets.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
}
return
}
@@ -131,7 +131,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Echo.Increment()
sent := stats.ICMP.V4PacketsSent
- if !r.Stack().AllowICMPMessage() {
+ if !e.protocol.stack.AllowICMPMessage() {
sent.RateLimited.Increment()
return
}
@@ -144,10 +144,13 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// waiting endpoints. Consider moving responsibility for doing the copy to
// DeliverTransportPacket so that it is only done when needed.
replyData := pkt.Data.ToOwnedView()
+ ipHdr := header.IPv4(pkt.NetworkHeader().View())
+ localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast
// It's possible that a raw socket expects to receive this.
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
pkt = nil
+
// Take the base of the incoming request IP header but replace the options.
replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions))
replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...))
@@ -156,12 +159,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
- localAddr := r.LocalAddress
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) {
+ localAddr := ipHdr.DestinationAddress()
+ if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
@@ -218,7 +221,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
case header.ICMPv4EchoReply:
received.EchoReply.Increment()
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
received.DstUnreachable.Increment()
@@ -307,7 +310,11 @@ func (*icmpReasonParamProblem) isICMPReason() {}
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ origIPHdr := header.IPv4(pkt.NetworkHeader().View())
+ origIPHdrSrc := origIPHdr.SourceAddress()
+ origIPHdrDst := origIPHdr.DestinationAddress()
+
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
@@ -331,8 +338,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
//
// TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in
// response to a non-initial fragment, but it currently can not happen.
-
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any {
+ if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any {
return nil
}
@@ -340,14 +346,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
- // From this point on, the incoming route should no longer be used; route
- // must be used to send the ICMP error.
- r = nil
sent := p.stack.Stats().ICMP.V4PacketsSent
if !p.stack.AllowICMPMessage() {
@@ -355,11 +358,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
return nil
}
- networkHeader := pkt.NetworkHeader().View()
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
- if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) {
+ if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) {
// TODO(gvisor.dev/issue/3810):
// Unfortunately the current stack pretty much always has ICMPv4 headers
// in the Data section of the packet but there is no guarantee that is the
@@ -416,7 +418,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
return nil
}
- payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size()
+ payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
@@ -428,7 +430,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// view with the entire incoming IP packet reassembled and truncated as
// required. This is now the payload of the new ICMP packet and no longer
// considered a packet in its own right.
- newHeader := append(buffer.View(nil), networkHeader...)
+ newHeader := append(buffer.View(nil), origIPHdr...)
newHeader = append(newHeader, transportHeader...)
payload := newHeader.ToVectorisedView()
payload.AppendView(pkt.Data.ToView())
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 4592984a5..cfd0c505a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -252,8 +252,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
@@ -270,16 +269,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
netHeader := header.IPv4(pkt.NetworkHeader().View())
ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.HandlePacket(pkt)
+ }
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, pkt)
- loopedR.Release()
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ loopedR := r.MakeLoopedRoute()
+ loopedR.PopulatePacketInfo(pkt)
+ loopedR.Release()
+ e.HandlePacket(pkt)
+ }
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -373,10 +383,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
+ }
n++
continue
}
@@ -403,6 +415,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
if !ok {
return tcpip.ErrMalformedHeader
}
+
+ hdrLen := header.IPv4(h).HeaderLength()
+ if hdrLen < header.IPv4MinimumSize {
+ return tcpip.ErrMalformedHeader
+ }
+
+ h, ok = pkt.Data.PullUp(int(hdrLen))
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
ip := header.IPv4(h)
// Always set the total length.
@@ -447,14 +469,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
+ pkt.NICID = e.nic.ID()
+ stats := e.protocol.stack.Stats()
+
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
@@ -480,7 +505,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
if h.CalculateChecksum() != 0xffff {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
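
The "sum folds to all ones" property the comment relies on can be checked in isolation. A self-contained sketch using RFC 1071 arithmetic and a textbook example header: summing every 16-bit word of a correct header, checksum field included, and folding the carries yields 0xffff, which is exactly the comparison performed above.

package main

import "fmt"

// internetChecksum folds 16-bit big-endian words with end-around carry.
func internetChecksum(b []byte) uint16 {
	var sum uint32
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return uint16(sum)
}

func main() {
	// A 20-byte IPv4 header whose checksum field (0xb861) is already correct.
	hdr := []byte{
		0x45, 0x00, 0x00, 0x73, 0x00, 0x00, 0x40, 0x00,
		0x40, 0x11, 0xb8, 0x61, 0xc0, 0xa8, 0x00, 0x01,
		0xc0, 0xa8, 0x00, 0xc7,
	}
	fmt.Printf("%#04x\n", internetChecksum(hdr)) // 0xffff for a valid header
}
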
@@ -488,8 +513,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
- r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
@@ -498,7 +523,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesInputDropped.Increment()
+ stats.IP.IPTablesInputDropped.Increment()
return
}
@@ -506,8 +531,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
// The packet is a fragment, let's try to reassemble it.
@@ -520,8 +545,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// size). Otherwise the packet would've been rejected as invalid before
// reaching here.
if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
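
A quick worked example of the bound enforced here; the constant is assumed to be the 65,535-byte maximum total length minus the 20-byte minimum header, and the offsets are illustrative.

package main

import "fmt"

// Fragment offsets travel in 8-byte units, so a final fragment that starts at
// offset unit 8189 (65512 bytes) and carries 512 bytes of payload would
// reassemble past the IPv4 maximum payload size and must be dropped as
// malformed, which is what the check above does.
func main() {
	const ipv4MaximumPayloadSize = 65535 - 20 // assumed value of header.IPv4MaximumPayloadSize
	start, payloadLen := 8189*8, 512          // illustrative final fragment
	fmt.Println(start+payloadLen > ipv4MaximumPayloadSize) // true: reject
}
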
@@ -537,12 +562,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
var releaseCB func(bool)
if start == 0 {
pkt := pkt.Clone()
- r := r.Clone()
releaseCB = func(timedOut bool) {
if timedOut {
- _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
}
- r.Release()
}
}
@@ -566,8 +589,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
releaseCB,
)
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
if !ready {
@@ -579,7 +602,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
h.SetFlagsFragmentOffset(0, 0)
}
- r.Stats().IP.PacketsDelivered.Increment()
+ stats.IP.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
@@ -587,14 +610,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// headers, the setting of the transport number here should be
// unnecessary and removed.
pkt.TransportProtocolNumber = p
- e.handleICMP(r, pkt)
+ e.handleICMP(pkt)
return
}
if len(h.Options()) != 0 {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
// rather than just throwing them away.
- aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{})
+ aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{})
if err != nil {
switch {
case
@@ -604,15 +627,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
errors.Is(err, errIPv4TimestampOptInvalidLength),
errors.Is(err, errIPv4TimestampOptInvalidPointer),
errors.Is(err, errIPv4TimestampOptOverflow):
- _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt)
- e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
+ stats.MalformedRcvdPackets.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
}
return
}
}
- switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
@@ -620,13 +643,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC: 1122 Section 3.2.2.1
// A host SHOULD generate Destination Unreachable messages with code:
// 2 (Protocol Unreachable), when the designated transport protocol
// is not supported
- _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -669,7 +692,7 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
loopback := e.nic.IsLoopback()
addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
- subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ subnet := addressEndpoint.Subnet()
// IPv4 has a notion of a subnet broadcast address and considers the
// loopback interface bound to an address's whole subnet (on linux).
return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
@@ -919,6 +942,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head
originalIPHeaderLength := len(originalIPHeader)
nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
@@ -1172,8 +1196,8 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// - The location of an error if there was one (or 0 if no error)
 // - If there is an error, information as to what it was.
// - The replacement option set.
-func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
-
+func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
+ stats := e.protocol.stack.Stats()
opts := header.IPv4Options(orig)
optIter := opts.MakeIterator()
@@ -1186,13 +1210,15 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag
// This will need tweaking when we start really forwarding packets
// as we may need to get two addresses, for rx and tx interfaces.
// We will also have to take usage into account.
- prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber)
+ prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
localAddress := prefixedAddress.Address
if err != nil {
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ dstAddr := h.DestinationAddress()
+ if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) {
return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress
}
- localAddress = r.LocalAddress
+ localAddress = dstAddr
}
for {
@@ -1219,9 +1245,9 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag
optLen := int(option.Size())
switch option := option.(type) {
case *header.IPv4OptionTimestamp:
- r.Stats().IP.OptionTSReceived.Increment()
+ stats.IP.OptionTSReceived.Increment()
if usage.actions().timestamp != optionRemove {
- clock := r.Stack().Clock()
+ clock := e.protocol.stack.Clock()
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage)
@@ -1232,7 +1258,7 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag
}
case *header.IPv4OptionRecordRoute:
- r.Stats().IP.OptionRRReceived.Increment()
+ stats.IP.OptionRRReceived.Increment()
if usage.actions().recordRoute != optionRemove {
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
@@ -1244,7 +1270,7 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag
}
default:
- r.Stats().IP.OptionUnknownReceived.Increment()
+ stats.IP.OptionUnknownReceived.Increment()
if usage.actions().unknown == optionPass {
newBuffer := optIter.RemainingBuffer()[:optLen]
// Arguments already heavily checked.. ignore result.
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 61672a5ff..c7f434591 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -2178,13 +2178,10 @@ func TestWriteStats(t *testing.T) {
// Install Output DROP rule.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
},
@@ -2199,17 +2196,14 @@ func TestWriteStats(t *testing.T) {
// of the 3 packets.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
},
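
These test call sites show the shape of the ID-based API introduced later in this diff (pkg/tcpip/stack/iptables.go): tables are addressed by stack.TableID constants instead of names, GetTable no longer returns an ok bool, and an invalid ID panics. A condensed sketch of the pattern, with a hypothetical helper name:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// installOutputDrop (hypothetical) fetches the filter table by ID, points the
// Output chain's entry rule at a DROP target, and writes the table back.
func installOutputDrop(ipt *stack.IPTables, ipv6 bool) *tcpip.Error {
	filter := ipt.GetTable(stack.FilterID, ipv6)
	ruleIdx := filter.BuiltinChains[stack.Output]
	filter.Rules[ruleIdx].Target = &stack.DropTarget{}
	return ipt.ReplaceTable(stack.FilterID, filter, ipv6)
}
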
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 3c15e41a7..8502b848c 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -124,8 +124,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
})
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
- stats := r.Stats().ICMP
+func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
+ stats := e.protocol.stack.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
@@ -138,13 +138,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
h := header.ICMPv6(v)
iph := header.IPv6(pkt.NetworkHeader().View())
+ srcAddr := iph.SourceAddress()
+ dstAddr := iph.DestinationAddress()
// Validate ICMPv6 checksum before processing the packet.
//
// This copy is used as extra payload during the checksum calculation.
payload := pkt.Data.Clone(nil)
payload.TrimFront(len(h))
- if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
+ if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want {
received.Invalid.Increment()
return
}
@@ -224,7 +226,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// we know we are also performing DAD on it). In this case we let the
// stack know so it can handle such a scenario and do nothing further with
// the NS.
- if r.RemoteAddress == header.IPv6Any {
+ if srcAddr == header.IPv6Any {
// We would get an error if the address no longer exists or the address
// is no longer tentative (DAD resolved between the call to
// hasTentativeAddr and this point). Both of these are valid scenarios:
@@ -251,7 +253,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -277,9 +279,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// Otherwise, on link layers that have addresses this option MUST be
// included in multicast solicitations and SHOULD be included in unicast
// solicitations.
- unspecifiedSource := r.RemoteAddress == header.IPv6Any
+ unspecifiedSource := srcAddr == header.IPv6Any
if len(sourceLinkAddr) == 0 {
- if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
+ if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource {
received.Invalid.Increment()
return
}
@@ -287,9 +289,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
received.Invalid.Increment()
return
} else if e.nud != nil {
- e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr)
}
// As per RFC 4861 section 7.1.1:
@@ -298,7 +300,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// ...
// - If the IP source address is the unspecified address, the IP
// destination address is a solicited-node multicast address.
- if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) {
+ if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) {
received.Invalid.Increment()
return
}
@@ -308,7 +310,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the source of the solicitation is the unspecified address, the node
// MUST [...] and multicast the advertisement to the all-nodes address.
//
- remoteAddr := r.RemoteAddress
+ remoteAddr := srcAddr
if unspecifiedSource {
remoteAddr = header.IPv6AllNodesMulticastAddress
}
@@ -465,12 +467,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
- localAddr := r.LocalAddress
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ localAddr := dstAddr
+ if header.IsV6MulticastAddress(dstAddr) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
@@ -486,7 +488,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -498,7 +504,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
received.Invalid.Increment()
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt)
+ e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt)
case header.ICMPv6TimeExceeded:
received.TimeExceeded.Increment()
@@ -519,7 +525,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- stack := r.Stack()
+ stack := e.protocol.stack
// Is the networking stack operating as a router?
if !stack.Forwarding(ProtocolNumber) {
@@ -550,7 +556,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST
// NOT be included when the source IP address is the unspecified address.
// Otherwise, it SHOULD be included on link layers that have addresses.
- if r.RemoteAddress == header.IPv6Any {
+ if srcAddr == header.IPv6Any {
received.Invalid.Increment()
return
}
@@ -558,7 +564,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
if e.nud != nil {
// A RS with a specified source IP address modifies the NUD state
// machine in the same way a reachability probe would.
- e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
}
}
@@ -575,7 +581,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- routerAddr := iph.SourceAddress()
+ routerAddr := srcAddr
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
@@ -608,7 +614,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the RA has the source link layer option, update the link address
// cache with the link address for the advertised router.
if len(sourceLinkAddr) != 0 && e.nud != nil {
- e.nud.HandleProbe(routerAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
}
e.mu.Lock()
@@ -753,7 +759,11 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ origIPHdr := header.IPv6(pkt.NetworkHeader().View())
+ origIPHdrSrc := origIPHdr.SourceAddress()
+ origIPHdrDst := origIPHdr.DestinationAddress()
+
// Only send ICMP error if the address is not a multicast v6
// address and the source is not the unspecified address.
//
@@ -780,7 +790,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
allowResponseToMulticast = reason.respondToMulticast
}
- if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any {
return nil
}
@@ -788,14 +798,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
- // From this point on, the incoming route should no longer be used; route
- // must be used to send the ICMP error.
- r = nil
stats := p.stack.Stats().ICMP
sent := stats.V6PacketsSent
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index aa8b5f2e5..76013daa1 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -87,7 +87,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
+func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
return stack.TransportPacketHandled
}
@@ -282,7 +282,8 @@ func TestICMPCounts(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
for _, typ := range types {
@@ -424,7 +425,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
for _, typ := range types {
@@ -1796,7 +1798,8 @@ func TestCallsToNeighborCache(t *testing.T) {
SrcAddr: r.RemoteAddress,
DstAddr: r.LocalAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
if nudHandler.probeCount != test.wantProbeCount {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 1e38f3a9d..0526190cc 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -166,7 +166,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
return err
}
- prefix := addressEndpoint.AddressWithPrefix().Subnet()
+ prefix := addressEndpoint.Subnet()
switch t := addressEndpoint.ConfigType(); t {
case stack.AddressConfigStatic:
@@ -465,21 +465,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.HandlePacket(pkt)
+ }
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- loopedR := r.MakeLoopedRoute()
-
- e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{
- // The inbound path expects an unparsed packet.
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- }))
-
- loopedR.Release()
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ loopedR := r.MakeLoopedRoute()
+ loopedR.PopulatePacketInfo(pkt)
+ loopedR.Release()
+ e.HandlePacket(pkt)
+ }
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -576,10 +582,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
+ }
n++
continue
}
@@ -637,22 +645,27 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
+ pkt.NICID = e.nic.ID()
+ stats := e.protocol.stack.Stats()
+
h := header.IPv6(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
// As per RFC 4291 section 2.7:
// Multicast addresses must not be used as source addresses in IPv6
// packets or appear in any Routing header.
- if header.IsV6MulticastAddress(r.RemoteAddress) {
- r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ if header.IsV6MulticastAddress(srcAddr) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
@@ -671,7 +684,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesInputDropped.Increment()
+ stats.IP.IPTablesInputDropped.Increment()
return
}
@@ -681,7 +694,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -693,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
if previousHeaderStart != 0 {
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: previousHeaderStart,
}, pkt)
@@ -705,7 +718,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -719,7 +732,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionDiscard:
return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ if header.IsV6MulticastAddress(dstAddr) {
return
}
fallthrough
@@ -732,7 +745,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// ICMP Parameter Problem, Code 2, message to the packet's
// Source Address, pointing to the unrecognized Option Type.
//
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
@@ -757,7 +770,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
if extHdr.SegmentsLeft() != 0 {
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -794,8 +807,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
it, done, err := it.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -822,8 +835,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
switch lastHdr.(type) {
case header.IPv6RawPayloadHeader:
default:
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
}
@@ -831,8 +844,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
fragmentPayloadLen := rawPayload.Buf.Size()
if fragmentPayloadLen == 0 {
// Drop the packet as it's marked as a fragment but has no payload.
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
@@ -845,9 +858,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// of the fragment, pointing to the Payload Length field of the
// fragment packet.
if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: header.IPv6PayloadLenOffset,
}, pkt)
@@ -866,9 +879,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// the fragment, pointing to the Fragment Offset field of the fragment
// packet.
if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: fragmentFieldOffset,
}, pkt)
@@ -880,12 +893,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
var releaseCB func(bool)
if start == 0 {
pkt := pkt.Clone()
- r := r.Clone()
releaseCB = func(timedOut bool) {
if timedOut {
- _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
}
- r.Release()
}
}
@@ -895,8 +906,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// IPv6 ignores the Protocol field since the ID only needs to be unique
// across source-destination pairs, as per RFC 8200 section 4.5.
fragmentation.FragmentID{
- Source: h.SourceAddress(),
- Destination: h.DestinationAddress(),
+ Source: srcAddr,
+ Destination: dstAddr,
ID: extHdr.ID(),
},
start,
@@ -907,8 +918,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
releaseCB,
)
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
pkt.Data = data
@@ -927,7 +938,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -941,7 +952,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionDiscard:
return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ if header.IsV6MulticastAddress(dstAddr) {
return
}
fallthrough
@@ -954,7 +965,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// ICMP Parameter Problem, Code 2, message to the packet's
// Source Address, pointing to the unrecognized Option Type.
//
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
@@ -977,13 +988,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
- r.Stats().IP.PacketsDelivered.Increment()
+ stats.IP.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
pkt.TransportProtocolNumber = p
- e.handleICMP(r, pkt, hasFragmentHeader)
+ e.handleICMP(pkt, hasFragmentHeader)
} else {
- r.Stats().IP.PacketsDelivered.Increment()
- switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ stats.IP.PacketsDelivered.Increment()
+ switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC 4443 section 3.1:
@@ -991,7 +1002,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// message with Code 4 in response to a packet for which the
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
- _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC 8200 section 4. (page 7):
// Extension headers are numbered from IANA IP Protocol Numbers
@@ -1012,7 +1023,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
//
// Which when taken together indicate that an unknown protocol should
// be treated as an unrecognized next header value.
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -1022,11 +1033,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
default:
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: it.ParseOffset(),
}, pkt)
- r.Stats().UnknownProtocolRcvdPackets.Increment()
+ stats.UnknownProtocolRcvdPackets.Increment()
return
}
}
@@ -1635,6 +1646,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea
originalIPHeadersLength := len(originalIPHeaders)
fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
// Copy the IPv6 header and any extension headers already populated.
if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index c593c0004..1bfcdde25 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -2360,13 +2360,10 @@ func TestWriteStats(t *testing.T) {
// Install Output DROP rule.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
},
@@ -2381,17 +2378,14 @@ func TestWriteStats(t *testing.T) {
// of the 3 packets.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
},
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 7f2ebc0cb..981d1371a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -573,6 +573,13 @@ func TestNeighorSolicitationResponse(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
}
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
@@ -993,7 +1000,8 @@ func TestNDPValidation(t *testing.T) {
if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
}
- ep.HandlePacket(r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
var tllData [header.NDPLinkLayerAddressSize]byte
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 261705575..9478f3fb7 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -272,6 +272,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
addrState = &addressState{
addressableEndpointState: a,
addr: addr,
+ // Cache the subnet in addrState to avoid calls to addr.Subnet() as that
+ // results in allocations on every call.
+ subnet: addr.Subnet(),
}
a.mu.endpoints[addr.Address] = addrState
addrState.mu.Lock()
@@ -666,7 +669,7 @@ var _ AddressEndpoint = (*addressState)(nil)
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
-
+ subnet tcpip.Subnet
// Lock ordering (from outer to inner lock ordering):
//
// AddressableEndpointState.mu
@@ -686,6 +689,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
}
+// Subnet implements AddressEndpoint.
+func (a *addressState) Subnet() tcpip.Subnet {
+ return a.subnet
+}
+
// GetKind implements AddressEndpoint.
func (a *addressState) GetKind() AddressKind {
a.mu.RLock()
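
A small sketch of why the cached subnet pays off, using only the tcpip types referenced in this diff: AddressWithPrefix.Subnet() allocates on every call (per the comment above), so computing it once when the address is added and returning the stored value from Subnet() spares the per-packet paths shown earlier in this diff (ipv4 AcquireAssignedAddress, ipv6 dupTentativeAddrDetected) that work. The addresses below are illustrative.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	addr := tcpip.AddressWithPrefix{
		Address:   tcpip.Address("\x0a\x00\x00\x01"), // 10.0.0.1
		PrefixLen: 8,
	}
	subnet := addr.Subnet() // in the patched code this is done once, at add time
	fmt.Println(subnet.Contains(tcpip.Address("\x0a\x01\x02\x03"))) // true: 10.1.2.3 is in 10.0.0.0/8
}
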
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 0cd1da11f..9a17efcba 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.Addr
- replyTID.srcPort = rt.Port
+ replyTID.srcAddr = address
+ replyTID.srcPort = port
var manip manipType
switch hook {
case Prerouting:
@@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
+ length := uint16(len(tcpHeader) + pkt.Data.Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
- } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size())
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumVV(pkt.Data, xsum)
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 380688038..7a501acdc 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
+func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 8d6d9a7f1..2d8c883cd 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -22,30 +22,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// tableID is an index into IPTables.tables.
-type tableID int
+// TableID identifies a specific table.
+type TableID int
+// Each value identifies a specific table.
const (
- natID tableID = iota
- mangleID
- filterID
- numTables
+ NATID TableID = iota
+ MangleID
+ FilterID
+ NumTables
)
-// Table names.
-const (
- NATTable = "nat"
- MangleTable = "mangle"
- FilterTable = "filter"
-)
-
-// nameToID is immutable.
-var nameToID = map[string]tableID{
- NATTable: natID,
- MangleTable: mangleID,
- FilterTable: filterID,
-}
-
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
@@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second
// all packets.
func DefaultTables() *IPTables {
return &IPTables{
- v4Tables: [numTables]Table{
- natID: Table{
+ v4Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -81,7 +68,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -99,7 +86,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -122,8 +109,8 @@ func DefaultTables() *IPTables {
},
},
},
- v6Tables: [numTables]Table{
- natID: Table{
+ v6Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -146,7 +133,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -164,7 +151,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -187,10 +174,10 @@ func DefaultTables() *IPTables {
},
},
},
- priorities: [NumHooks][]tableID{
- Prerouting: []tableID{mangleID, natID},
- Input: []tableID{natID, filterID},
- Output: []tableID{mangleID, natID, filterID},
+ priorities: [NumHooks][]TableID{
+ Prerouting: []TableID{MangleID, NATID},
+ Input: []TableID{NATID, FilterID},
+ Output: []TableID{MangleID, NATID, FilterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -229,26 +216,20 @@ func EmptyNATTable() Table {
}
}
-// GetTable returns a table by name.
-func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
- id, ok := nameToID[name]
- if !ok {
- return Table{}, false
- }
+// GetTable returns a table with the given id and IP version. It panics when an
+// invalid id is provided.
+func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
it.mu.RLock()
defer it.mu.RUnlock()
if ipv6 {
- return it.v6Tables[id], true
+ return it.v6Tables[id]
}
- return it.v4Tables[id], true
+ return it.v4Tables[id]
}
-// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
- id, ok := nameToID[name]
- if !ok {
- return tcpip.ErrInvalidOptionValue
- }
+// ReplaceTable replaces or inserts a table by id. It panics when an invalid
+// id is provided.
+func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error {
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
for _, tableID := range priorities {
// If handlePacket already NATed the packet, we don't need to
// check the NAT table.
- if tableID == natID && pkt.NatDone {
+ if tableID == NATID && pkt.NatDone {
continue
}
var table Table
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 538c4625d..d63e9757c 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -26,13 +28,6 @@ type AcceptTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (at *AcceptTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: at.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
@@ -44,22 +39,11 @@ type DropTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (dt *DropTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: dt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
-// ErrorTargetName is used to mark targets as error targets. Error targets
-// shouldn't be reached - an error has occurred if we fall through to one.
-const ErrorTargetName = "ERROR"
-
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
type ErrorTarget struct {
@@ -67,14 +51,6 @@ type ErrorTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (et *ErrorTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: et.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
@@ -90,14 +66,6 @@ type UserChainTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (uc *UserChainTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: uc.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
@@ -110,50 +78,39 @@ type ReturnTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *ReturnTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
-// RedirectTargetName is used to mark targets as redirect targets. Redirect
-// targets should be reached for only NAT and Mangle tables. These targets will
-// change the destination port/destination IP for packets.
-const RedirectTargetName = "REDIRECT"
-
-// RedirectTarget redirects the packet by modifying the destination port/IP.
+// RedirectTarget redirects the packet to this machine by modifying the
+// destination port/IP. Outgoing packets are redirected to the loopback device,
+// and incoming packets are redirected to the incoming interface (rather than
+// forwarded).
+//
// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
// them.
type RedirectTarget struct {
- // Addr indicates address used to redirect.
- Addr tcpip.Address
-
- // Port indicates port used to redirect.
+ // Port indicates port used to redirect. It is immutable.
Port uint16
- // NetworkProtocol is the network protocol the target is used with.
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *RedirectTarget) ID() TargetID {
- return TargetID{
- Name: RedirectTargetName,
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// implementation only works for Prerouting and calls pkt.Clone(), neither
// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ address = tcpip.Address([]byte{127, 0, 0, 1})
} else {
- rt.Addr = header.IPv6Loopback
+ address = header.IPv6Loopback
}
case Prerouting:
- rt.Addr = address
+ // No-op, as address is already set correctly.
default:
panic("redirect target is supported only on output and prerouting hooks")
}
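
With the Addr field removed from RedirectTarget, a REDIRECT rule now carries only the port and network protocol; the destination address is chosen at Action time as shown above (loopback on Output, the incoming interface's address on Prerouting). A rule built under the new shape, with illustrative values:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// redirectRule builds a NAT rule that redirects matched traffic to local port
// 8080. The port and protocol here are illustrative, not part of the patch.
func redirectRule() stack.Rule {
	return stack.Rule{
		Target: &stack.RedirectTarget{
			Port:            8080,
			NetworkProtocol: header.IPv4ProtocolNumber,
		},
	}
}
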
@@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
+ netHeader := pkt.Network()
+ netHeader.SetDestinationAddress(address)
// Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ if r.RequiresTXTransportChecksum() {
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(protocol, length)
- for _, v := range pkt.Data.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udpHeader.SetChecksum(0)
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumVV(pkt.Data, xsum)
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- pkt.Network().SetDestinationAddress(rt.Addr)
-
// After modification, IPv4 packets need a valid checksum.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
 // Set up connection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 7b3f3e88b..4b86c1be9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -37,7 +37,6 @@ import (
// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
type Hook uint
-// These values correspond to values in include/uapi/linux/netfilter.h.
const (
// Prerouting happens before a packet is routed to applications or to
// be forwarded.
@@ -86,8 +85,8 @@ type IPTables struct {
mu sync.RWMutex
 // v4Tables and v6Tables map tableIDs to tables. They hold builtin
// tables only, not user tables. mu must be locked for accessing.
- v4Tables [numTables]Table
- v6Tables [numTables]Table
+ v4Tables [NumTables]Table
+ v6Tables [NumTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
@@ -96,7 +95,7 @@ type IPTables struct {
 // priorities maps each hook to a list of table IDs. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
- priorities [NumHooks][]tableID
+ priorities [NumHooks][]TableID
connections ConnTrack
@@ -104,6 +103,24 @@ type IPTables struct {
reaperDone chan struct{}
}
+// VisitTargets traverses all the targets of all tables and replaces each with
+// transform(target).
+func (it *IPTables) VisitTargets(transform func(Target) Target) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ for tid := range it.v4Tables {
+ for i, rule := range it.v4Tables[tid].Rules {
+ it.v4Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+ for tid := range it.v6Tables {
+ for i, rule := range it.v6Tables[tid].Rules {
+ it.v6Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+}
+
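
A hedged sketch of how the new VisitTargets hook might be used; the concrete transform shown here (stamping a network protocol onto accept targets) is purely illustrative, the point is the in-place traversal of every rule in every table.

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// stampProtocol walks every rule of every table and rewrites targets in place.
func stampProtocol(ipt *stack.IPTables) {
	ipt.VisitTargets(func(t stack.Target) stack.Target {
		if at, ok := t.(*stack.AcceptTarget); ok && at.NetworkProtocol == 0 {
			at.NetworkProtocol = header.IPv4ProtocolNumber
		}
		return t
	})
}
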
// A Table defines a set of chains and hooks into the network stack.
//
// It is a list of Rules, entry points (BuiltinChains), and error handlers
@@ -169,7 +186,6 @@ type IPHeaderFilter struct {
// CheckProtocol determines whether the Protocol field should be
// checked during matching.
- // TODO(gvisor.dev/issue/3549): Check this field during matching.
CheckProtocol bool
// Dst matches the destination IP address.
@@ -309,23 +325,8 @@ type Matcher interface {
Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
-// A TargetID uniquely identifies a target.
-type TargetID struct {
- // Name is the target name as stored in the xt_entry_target struct.
- Name string
-
- // NetworkProtocol is the protocol to which the target applies.
- NetworkProtocol tcpip.NetworkProtocolNumber
-
- // Revision is the version of the target.
- Revision uint8
-}
-
// A Target is the interface for taking an action for a packet.
type Target interface {
- // ID uniquely identifies the Target.
- ID() TargetID
-
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
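
A minimal usage sketch for the new IPTables.VisitTargets method, assuming package stack; the countTargets helper and its identity transform are illustration-only and not part of this change.

// Sketch: walk every rule's target in both the IPv4 and IPv6 tables and
// tally them by concrete type, leaving each target unchanged.
package stack

import "fmt"

func countTargets(it *IPTables) map[string]int {
	counts := make(map[string]int)
	it.VisitTargets(func(t Target) Target {
		counts[fmt.Sprintf("%T", t)]++
		return t // identity transform: keep the installed target as-is
	})
	return counts
}

A transform that returns a different Target replaces the installed target in place, which is the kind of wholesale rewrite across all tables this method is meant to enable.
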
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index aec77610d..493e48031 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -17,12 +17,19 @@ package stack
import (
"fmt"
"sync"
+ "time"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+const (
+ // immediateDuration is a duration of zero for scheduling work that needs to
+ // be done immediately but asynchronously to avoid deadlock.
+ immediateDuration time.Duration = 0
+)
+
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
Addr tcpip.Address
@@ -242,7 +249,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
e.job.Schedule(config.RetransmitTimer)
}
- sendUnicastProbe()
+ // Send a probe in another goroutine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as in loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(immediateDuration)
case Failed:
e.notifyWakersLocked()
@@ -324,7 +336,12 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
e.job.Schedule(config.RetransmitTimer)
}
- sendMulticastProbe()
+ // Send a probe in another goroutine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as in loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(immediateDuration)
case Stale:
e.setStateLocked(Delay)
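
A standalone sketch of the "schedule now, run asynchronously" pattern the comments above describe, built directly on tcpip.NewJob and the faketime manual clock; main, probesSent, and the local immediateDuration constant are illustration-only assumptions, not part of this change.

package main

import (
	"fmt"
	"sync"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/faketime"
)

// immediateDuration mirrors the zero duration used by neighborEntry.
const immediateDuration time.Duration = 0

func main() {
	clock := faketime.NewManualClock()
	var mu sync.Mutex
	probesSent := 0

	mu.Lock()
	// While mu is held, schedule the probe instead of sending it inline.
	// Sending synchronously could deadlock when transmit feeds straight
	// back into a receive path that also needs mu (loopback, tests).
	job := tcpip.NewJob(clock, &mu, func() {
		// NewJob locks mu before running this function.
		probesSent++
	})
	job.Schedule(immediateDuration)
	mu.Unlock()

	// Advancing the manual clock by zero runs zero-delay jobs, which is
	// what runImmediatelyScheduledJobs relies on in the tests below.
	clock.Advance(immediateDuration)

	mu.Lock()
	fmt.Println("probes sent:", probesSent) // 1
	mu.Unlock()
}
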
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index d297c9422..c2b763325 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -47,6 +47,12 @@ const (
entryTestNetDefaultMTU = 65536
)
+// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
+// time.
+func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
+ clock.Advance(immediateDuration)
+}
+
// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
// The UpdatedAtNanos field is ignored due to a lack of a deterministic method
// to predict the time that an event will be dispatched.
@@ -308,7 +314,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
func TestEntryUnknownToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
@@ -317,6 +323,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -354,7 +361,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
func TestEntryUnknownToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
@@ -364,6 +371,7 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Unlock()
// No probes should have been sent.
+ runImmediatelyScheduledJobs(clock)
linkRes.mu.Lock()
diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
linkRes.mu.Unlock()
@@ -488,23 +496,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
func TestEntryIncompleteToReachable(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -519,6 +520,17 @@ func TestEntryIncompleteToReachable(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -552,7 +564,7 @@ func TestEntryIncompleteToReachable(t *testing.T) {
// to Reachable.
func TestEntryAddsAndClearsWakers(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
w := sleep.Waker{}
s := sleep.Sleeper{}
@@ -561,6 +573,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
if got := e.wakers; got != nil {
t.Errorf("got e.wakers = %v, want = nil", got)
}
@@ -584,20 +614,6 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -627,26 +643,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: true,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.isRouter, true; got != want {
- t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -660,6 +666,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
}
linkRes.mu.Unlock()
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -689,23 +709,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
func TestEntryIncompleteToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -720,6 +733,17 @@ func TestEntryIncompleteToStale(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -830,12 +854,30 @@ func (*testLocker) Unlock() {}
func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -861,20 +903,6 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -910,27 +938,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -945,6 +959,24 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -983,16 +1015,9 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1007,6 +1032,17 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
@@ -1053,24 +1089,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1085,6 +1110,21 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1119,38 +1159,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1165,6 +1184,25 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1199,38 +1237,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1245,6 +1262,25 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1279,37 +1315,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1324,6 +1340,24 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1353,31 +1387,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1392,6 +1408,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1430,10 +1468,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1455,20 +1511,6 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1507,31 +1549,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1546,6 +1570,28 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1584,27 +1630,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1619,6 +1651,24 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1657,24 +1707,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
func TestEntryStaleToDelay(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1689,6 +1728,21 @@ func TestEntryStaleToDelay(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1736,21 +1790,9 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleUpperLevelConfirmationLocked()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1765,8 +1807,23 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1833,28 +1890,9 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1869,8 +1907,30 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1937,6 +1997,24 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1959,22 +2037,7 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2031,32 +2094,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2071,6 +2115,29 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2109,25 +2176,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2142,6 +2197,22 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2189,29 +2260,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked(entryTestAddr2)
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2226,6 +2281,26 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2277,6 +2352,27 @@ func TestEntryDelayToProbe(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2289,25 +2385,19 @@ func TestEntryDelayToProbe(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
@@ -2367,6 +2457,27 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2376,25 +2487,19 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2459,12 +2564,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
@@ -2473,6 +2572,27 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2482,25 +2602,19 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2569,12 +2683,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
@@ -2583,6 +2691,27 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2592,25 +2721,20 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2696,9 +2820,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
wantProbes := []entryTestProbeInfo{
- // Probe caused by the Delay-to-Probe transition
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
@@ -2729,7 +2851,6 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2795,6 +2916,27 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2804,25 +2946,19 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2843,7 +2979,6 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -2918,6 +3053,27 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -2927,25 +3083,19 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2963,7 +3113,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -3038,6 +3187,27 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -3047,25 +3217,19 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -3083,7 +3247,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -3158,9 +3321,9 @@ func TestEntryProbeToFailed(t *testing.T) {
e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
{
wantProbes := []entryTestProbeInfo{
- // Caused by the Unknown-to-Incomplete transition.
{
RemoteAddress: entryTestAddr1,
LocalAddress: entryTestAddr2,
@@ -3283,6 +3446,26 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
e.mu.Lock()
e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -3293,33 +3476,28 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
clock.Advance(waitFor)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The next three probe are caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+				// The next three probes are sent while the entry is in Probe.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
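
Note on the reworked entry tests above: the probe checks now run in two phases, once right after the immediately scheduled jobs run (clearing the recorded slice) and once after the delay timer fires, so each comparison only sees the probes of its own phase. A stripped-down sketch of that drain-and-compare pattern, using a hypothetical probeRecorder in place of the test's linkRes:

package main

import (
	"fmt"
	"sync"
)

// probeRecorder stands in for the test's link resolver: it accumulates the
// probes that were sent and lets a test drain them between phases.
type probeRecorder struct {
	mu     sync.Mutex
	probes []string
}

func (r *probeRecorder) record(p string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.probes = append(r.probes, p)
}

// drain returns the probes seen so far and clears the slice so the next
// phase of the test only sees its own probes.
func (r *probeRecorder) drain() []string {
	r.mu.Lock()
	defer r.mu.Unlock()
	got := r.probes
	r.probes = nil
	return got
}

func main() {
	var r probeRecorder
	r.record("multicast probe to entryTestAddr1") // phase 1: address resolution
	fmt.Println("phase 1:", r.drain())

	r.record("unicast probe to entryTestLinkAddr1") // phase 2: reachability probe
	fmt.Println("phase 2:", r.drain())
}
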
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 17f2e6b46..60c81a3aa 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -348,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address
return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
+func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint)
+ if ep != nil {
+ ep.DecRef()
+ return true
+ }
+
+ return false
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
@@ -555,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
}
func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
defer r.Release()
- r.RemoteLinkAddress = remotelinkAddr
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ n.getNetworkEndpoint(protocol).HandlePacket(pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -594,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
if local == "" {
local = n.LinkEndpoint.LinkAddress()
}
+ pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
@@ -669,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
// Found a NIC.
- n := r.nic
+ n := r.localAddressNIC
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
- r.RemoteLinkAddress = remote
+ pkt.NICID = n.ID()
r.RemoteAddress = src
- // TODO(b/123449044): Update the source NIC as well.
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ pkt.NetworkPacketInfo = r.networkPacketInfo()
+ n.getNetworkEndpoint(protocol).HandlePacket(pkt)
addressEndpoint.DecRef()
r.Release()
return
@@ -735,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
+func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -747,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- n.stack.demux.deliverRawPacket(r, protocol, pkt)
+ n.stack.demux.deliverRawPacket(protocol, pkt)
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
@@ -776,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
return TransportPacketHandled
}
- id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
+ netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber]
+ if !ok {
+ panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers()))
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
+ id := TransportEndpointID{
+ LocalPort: dstPort,
+ LocalAddress: dst,
+ RemotePort: srcPort,
+ RemoteAddress: src,
+ }
+ if n.stack.demux.deliverPacket(protocol, pkt, id) {
return TransportPacketHandled
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, pkt) {
+ if state.defaultHandler(id, pkt) {
return TransportPacketHandled
}
}
@@ -791,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// We could not find an appropriate destination for this packet so
// give the protocol specific error handler a chance to handle it.
// If it doesn't handle it then we should do so.
- switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res {
+ switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res {
case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
return TransportPacketHandled
@@ -895,7 +916,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed) unless the NIC is in spoofing mode, or temporary.
// has been removed) unless the NIC is in spoofing mode, or temporary.
func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
n.mu.RLock()
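
DeliverTransportPacket above no longer reads the endpoint addresses off a Route; it re-parses them from the packet's network header via the protocol's ParseAddresses and builds the demux ID from the result. A minimal, self-contained sketch of that pattern, with toy address and endpointID types standing in for the stack's real ones:

package main

import "fmt"

// Toy stand-ins for tcpip.Address and stack.TransportEndpointID.
type address string

type endpointID struct {
	localPort, remotePort uint16
	localAddr, remoteAddr address
}

// parseAddresses mimics NetworkProtocol.ParseAddresses for a toy header in
// which byte 0 is the destination address and byte 1 is the source address.
func parseAddresses(netHdr []byte) (src, dst address) {
	return address(netHdr[1:2]), address(netHdr[0:1])
}

func main() {
	netHdr := []byte{0x03, 0x01} // dst = \x03, src = \x01
	src, dst := parseAddresses(netHdr)

	// Build the demux key from the parsed header plus the transport ports,
	// instead of reading the addresses off a Route.
	id := endpointID{
		localPort:  4242,
		localAddr:  dst,
		remotePort: 1234,
		remoteAddr: src,
	}
	fmt.Printf("%+v\n", id)
}
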
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 4af04846f..5b5c58afb 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
}
// HandlePacket implements NetworkEndpoint.HandlePacket.
-func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
-}
+func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {}
// Close implements NetworkEndpoint.Close.
func (e *testIPv6Endpoint) Close() {
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 7f54a6de8..664cc6fa0 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -112,6 +112,16 @@ type PacketBuffer struct {
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
+
+	// NICID is the ID of the interface the network packet was received on.
+ NICID tcpip.NICID
+
+ // RXTransportChecksumValidated indicates that transport checksum verification
+ // may be safely skipped.
+ RXTransportChecksumValidated bool
+
+ // NetworkPacketInfo holds an incoming packet's network-layer information.
+ NetworkPacketInfo NetworkPacketInfo
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
// Clone should be called in such cases so that no modifications are done to the
// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
- newPk := &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- headers: pk.headers,
- header: pk.header,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
- TransportProtocolNumber: pk.TransportProtocolNumber,
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
+ PktType: pk.PktType,
+ NICID: pk.NICID,
+ RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
+ NetworkPacketInfo: pk.NetworkPacketInfo,
}
- return newPk
+}
+
+// SourceLinkAddress returns the source link address of the packet.
+func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress {
+ link := pk.LinkHeader().View()
+
+ if link.IsEmpty() {
+ return ""
+ }
+
+ return header.Ethernet(link).SourceAddress()
}
// Network returns the network header as a header.Network.
@@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
+// CloneToInbound makes a shallow copy of the packet buffer to be used as an
+// inbound packet.
+//
+// See PacketBuffer.Data for details about how a packet buffer holds an inbound
+// packet.
+func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
+ return NewPacketBuffer(PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
+ })
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
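
Clone above copies every metadata field (including the new NICID, RXTransportChecksumValidated and NetworkPacketInfo fields), while CloneToInbound deliberately starts from a fresh buffer that carries only the payload so the copy can be re-handled as a new inbound packet. A rough illustration of the difference, with a cut-down pkt type rather than the real PacketBuffer:

package main

import "fmt"

// pkt is a cut-down stand-in for stack.PacketBuffer.
type pkt struct {
	data    []byte
	nicID   int
	natDone bool
}

// clone copies the payload and all metadata, like PacketBuffer.Clone.
func (p *pkt) clone() *pkt {
	return &pkt{
		data:    append([]byte(nil), p.data...),
		nicID:   p.nicID,
		natDone: p.natDone,
	}
}

// cloneToInbound keeps only the payload so the copy can be re-handled as a
// brand new inbound packet, like PacketBuffer.CloneToInbound.
func (p *pkt) cloneToInbound() *pkt {
	return &pkt{data: append([]byte(nil), p.data...)}
}

func main() {
	orig := &pkt{data: []byte{1, 2, 3}, nicID: 7, natDone: true}
	fmt.Printf("clone:          %+v\n", orig.clone())
	fmt.Printf("cloneToInbound: %+v\n", orig.cloneToInbound())
}
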
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index f838eda8d..5d364a2b0 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
} else if _, err := p.route.Resolve(nil); err != nil {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
+ p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 203f3b51f..b8f333057 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -63,17 +63,28 @@ const (
ControlUnknown
)
+// NetworkPacketInfo holds information about a network layer packet.
+type NetworkPacketInfo struct {
+ // RemoteAddressBroadcast is true if the packet's remote address is a
+ // broadcast address.
+ RemoteAddressBroadcast bool
+
+ // LocalAddressBroadcast is true if the packet's local address is a broadcast
+ // address.
+ LocalAddressBroadcast bool
+}
+
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
// UniqueID returns a unique ID for this transport endpoint.
UniqueID() uint64
- // HandlePacket is called by the stack when new packets arrive to
- // this transport endpoint. It sets pkt.TransportHeader.
+ // HandlePacket is called by the stack when new packets arrive to this
+ // transport endpoint. It sets the packet buffer's transport header.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(TransportEndpointID, *PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
@@ -105,8 +116,8 @@ type RawTransportEndpoint interface {
// this transport endpoint. The packet contains all data from the link
// layer up.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(*PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -172,9 +183,9 @@ type TransportProtocol interface {
// protocol that don't match any existing endpoint. For example,
// it is targeted at a port that has no listeners.
//
- // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // HandleUnknownDestinationPacket takes ownership of the packet if it handles
// the issue.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
+ HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -227,8 +238,8 @@ type TransportDispatcher interface {
//
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
- // DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition
+ // DeliverTransportPacket takes ownership of the packet.
+ DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -329,6 +340,9 @@ type AssignableAddressEndpoint interface {
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix
+ // Subnet returns the subnet of the endpoint's address.
+ Subnet() tcpip.Subnet
+
// IsAssigned returns whether or not the endpoint is considered bound
// to its NetworkEndpoint.
IsAssigned(allowExpired bool) bool
@@ -547,7 +561,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ HandlePacket(pkt *PacketBuffer)
// Close is called when the endpoint is removed from a stack.
Close()
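
With the Route parameter dropped from the Handle* interfaces above, an endpoint implementation now pulls what it needs (interface ID, broadcast flags, checksum state) from the packet buffer itself. A toy endpoint written against simplified stand-in types, only to show the shape of the new signatures:

package main

import "fmt"

// Simplified stand-ins for stack.NetworkPacketInfo and stack.PacketBuffer.
type networkPacketInfo struct {
	localAddressBroadcast bool
}

type packetBuffer struct {
	nicID                        int
	rxTransportChecksumValidated bool
	netInfo                      networkPacketInfo
	payload                      []byte
}

// toyEndpoint mirrors the new HandlePacket(*PacketBuffer) shape: everything it
// needs travels with the packet instead of a Route argument.
type toyEndpoint struct{ handled int }

func (e *toyEndpoint) handlePacket(p *packetBuffer) {
	if p.netInfo.localAddressBroadcast {
		fmt.Printf("broadcast packet received on NIC %d\n", p.nicID)
	}
	if !p.rxTransportChecksumValidated {
		fmt.Println("would verify the transport checksum in software")
	}
	e.handled++
}

func main() {
	var e toyEndpoint
	e.handlePacket(&packetBuffer{
		nicID:   1,
		netInfo: networkPacketInfo{localAddressBroadcast: true},
		payload: []byte{0xde, 0xad},
	})
	fmt.Println("packets handled:", e.handled)
}
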
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index b76e2d37b..15ff437c7 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -45,11 +47,16 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
- // nic is the NIC the route goes through.
- nic *NIC
+ // localAddressNIC is the interface the address is associated with.
+ // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
+ // address's assigned status without the NIC.
+ localAddressNIC *NIC
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
- // addressEndpoint is the local address this route is associated with.
- addressEndpoint AssignableAddressEndpoint
+ // outgoingNIC is the interface this route uses to write packets.
+ outgoingNIC *NIC
// linkCache is set if link address resolution is enabled for this protocol on
// the route's NIC.
@@ -60,51 +67,144 @@ type Route struct {
linkRes LinkAddressResolver
}
+// constructAndValidateRoute validates and initializes a route. It takes
+// ownership of the provided local address endpoint.
+//
+// Returns an empty route if validation fails.
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
+ addrWithPrefix := addressEndpoint.AddressWithPrefix()
+
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ addressEndpoint.DecRef()
+ return Route{}
+ }
+
+ // If no remote address is provided, use the local address.
+ if len(remoteAddr) == 0 {
+ remoteAddr = addrWithPrefix.Address
+ }
+
+ r := makeRoute(
+ netProto,
+ addrWithPrefix.Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ addressEndpoint,
+ handleLocal,
+ multicastLoop,
+ )
+
+ // If the route requires us to send a packet through some gateway, do not
+ // broadcast it.
+ if len(gateway) > 0 {
+ r.NextHop = gateway
+ } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ }
+
+ return r
+}
+
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+ if localAddressNIC.stack != outgoingNIC.stack {
+		panic("cannot create a route with NICs from different stacks")
+ }
+
loop := PacketOut
- if handleLocal && localAddr != "" && remoteAddr == localAddr {
- loop = PacketLoop
- } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
- loop |= PacketLoop
- } else if remoteAddr == header.IPv4Broadcast {
- loop |= PacketLoop
+
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if !outgoingNIC.IsLoopback() {
+ if handleLocal && localAddr != "" && remoteAddr == localAddr {
+ loop = PacketLoop
+ } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
+ loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
+ } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
+ loop |= PacketLoop
+ }
}
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- addressEndpoint: addressEndpoint,
- nic: nic,
- Loop: loop,
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ localAddressNIC: localAddressNIC,
+ localAddressEndpoint: localAddressEndpoint,
+ outgoingNIC: outgoingNIC,
+ Loop: loop,
}
- if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = r.nic.stack
+ r.linkCache = r.outgoingNIC.stack
}
}
return r
}
+// makeLocalRoute initializes a new local route. It takes ownership of the
+// provided AssignableAddressEndpoint.
+//
+// A local route is a route to a destination that is local to the stack.
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+ loop := PacketLoop
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if outgoingNIC.IsLoopback() {
+ loop = PacketOut
+ }
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+// PopulatePacketInfo populates a packet buffer's packet information fields.
+//
+// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
+// the network layer.
+func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) {
+ if r.local() {
+ pkt.RXTransportChecksumValidated = true
+ }
+ pkt.NetworkPacketInfo = r.networkPacketInfo()
+}
+
+// networkPacketInfo returns the network packet information of the route.
+//
+// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
+// the network layer.
+func (r *Route) networkPacketInfo() NetworkPacketInfo {
+ return NetworkPacketInfo{
+ RemoteAddressBroadcast: r.IsOutboundBroadcast(),
+ LocalAddressBroadcast: r.isInboundBroadcast(),
+ }
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.nic.ID()
+ return r.outgoingNIC.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
- return r.nic.stack.Stats()
+ return r.outgoingNIC.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
}
-// Capabilities returns the link-layer capabilities of the route.
-func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint.Capabilities()
+// RequiresTXTransportChecksum returns false if the route does not require
+// transport checksums to be populated.
+func (r *Route) RequiresTXTransportChecksum() bool {
+ if r.local() {
+ return false
+ }
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0
+}
+
+// HasSoftwareGSOCapability returns true if the route supports software GSO.
+func (r *Route) HasSoftwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+}
+
+// HasHardwareGSOCapability returns true if the route supports hardware GSO.
+func (r *Route) HasHardwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+}
+
+// HasSaveRestoreCapability returns true if the route supports save/restore.
+func (r *Route) HasSaveRestoreCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0
+}
+
+// HasDisconncetOkCapability returns true if the route supports disconnecting.
+func (r *Route) HasDisconncetOkCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker)
+ // If specified, the local address used for link address resolution must be an
+ // address on the outgoing interface.
+ var linkAddressResolutionRequestLocalAddr tcpip.Address
+ if r.localAddressNIC == r.outgoingNIC {
+ linkAddressResolutionRequestLocalAddr = r.LocalAddress
+ }
+
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
+ entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
if err != nil {
return ch, err
}
@@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
if err != nil {
return ch, err
}
@@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
neigh.removeWaker(nextAddr, waker)
return
}
- r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker)
+ r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
+}
+
+// local returns true if the route is a local route.
+func (r *Route) local() bool {
+ return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
-// the link address before the this route can be written to.
+// the link address before the route can be written to.
//
-// The NIC r uses must not be locked.
+// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if r.nic.neigh != nil {
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == ""
+ if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ return false
}
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == ""
+
+ return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil
+}
+
+func (r *Route) isValidForOutgoing() bool {
+ if !r.outgoingNIC.Enabled() {
+ return false
+ }
+
+ if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ return false
+ }
+
+	// If the source NIC and outgoing NIC are different, make sure the stack
+	// either has forwarding enabled, or will handle the packet locally (the
+	// stack handles local traffic and the outgoing NIC owns the remote
+	// address).
+ if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) {
+ return false
+ }
+
+ return true
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.nic.getNetworkEndpoint(r.NetProto).MTU()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.addressEndpoint != nil {
- r.addressEndpoint.DecRef()
- r.addressEndpoint = nil
+ if r.localAddressEndpoint != nil {
+ r.localAddressEndpoint.DecRef()
+ r.localAddressEndpoint = nil
}
}
// Clone clones the route.
func (r *Route) Clone() Route {
- if r.addressEndpoint != nil {
- _ = r.addressEndpoint.IncRef()
+ if r.localAddressEndpoint != nil {
+ if !r.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
+ }
}
return *r
}
@@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route {
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
- return r.nic.stack
+ return r.outgoingNIC.stack
}
func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
@@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.addressEndpoint.AddressWithPrefix().Subnet()
+ subnet := r.localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool {
return r.isV4Broadcast(r.RemoteAddress)
}
-// IsInboundBroadcast returns true if the route is for an inbound broadcast
+// isInboundBroadcast returns true if the route is for an inbound broadcast
// packet.
-func (r *Route) IsInboundBroadcast() bool {
+func (r *Route) isInboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.LocalAddress)
}
@@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool {
// ReverseRoute returns new route with given source and destination address.
func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- addressEndpoint: r.addressEndpoint,
- nic: r.nic,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
+ NetProto: r.NetProto,
+ LocalAddress: dst,
+ LocalLinkAddress: r.RemoteLinkAddress,
+ RemoteAddress: src,
+ RemoteLinkAddress: r.LocalLinkAddress,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ localAddressEndpoint: r.localAddressEndpoint,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
}
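
The blanket Capabilities() accessor is replaced above with intent-named helpers, so callers ask a question (for example, whether a transport checksum must still be filled in) instead of masking raw capability bits at every call site. A self-contained sketch of that pattern with made-up flag values, not the stack's real constants:

package main

import "fmt"

// capabilities is a made-up bit set; the real values are
// LinkEndpointCapabilities constants in the stack package.
type capabilities uint32

const capTXChecksumOffload capabilities = 1 << 0

type route struct {
	caps     capabilities
	loopback bool
}

// requiresTXTransportChecksum mirrors Route.RequiresTXTransportChecksum: local
// (looped-back) traffic and checksum-offloading links skip the software work.
func (r *route) requiresTXTransportChecksum() bool {
	if r.loopback {
		return false
	}
	return r.caps&capTXChecksumOffload == 0
}

func main() {
	hw := &route{caps: capTXChecksumOffload}
	lo := &route{loopback: true}
	sw := &route{}
	fmt.Println(hw.requiresTXTransportChecksum()) // false: the link fills it in
	fmt.Println(lo.requiresTXTransportChecksum()) // false: never leaves the host
	fmt.Println(sw.requiresTXTransportChecksum()) // true: software must compute it
}
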
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index e8f1c110e..a23fb97ff 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"bytes"
"encoding/binary"
+ "fmt"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -52,7 +53,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -518,6 +519,10 @@ type Options struct {
//
// RandSource must be thread-safe.
RandSource mathrand.Source
+
+ // IPTables are the initial iptables rules. If nil, iptables will allow
+ // all traffic.
+ IPTables *IPTables
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -620,6 +625,10 @@ func New(opts Options) *Stack {
randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
}
+ if opts.IPTables == nil {
+ opts.IPTables = DefaultTables()
+ }
+
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
@@ -633,7 +642,7 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- tables: DefaultTables(),
+ tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
nudConfigs: opts.NUDConfigs,
@@ -751,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -1194,54 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
+// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route
+// from the specified NIC.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
+ if localAddressEndpoint == nil {
+ return Route{}, false
+ }
+
+ var outgoingNIC *NIC
+ // Prefer a local route to the same interface as the local address.
+ if localAddressNIC.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = localAddressNIC
+ }
+
+ // If the remote address isn't owned by the local address's NIC, check all
+ // NICs.
+ if outgoingNIC == nil {
+ for _, nic := range s.nics {
+ if nic.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = nic
+ break
+ }
+ }
+ }
+
+ // If the remote address is not owned by the stack, we can't return a local
+ // route.
+ if outgoingNIC == nil {
+ localAddressEndpoint.DecRef()
+ return Route{}, false
+ }
+
+ r := makeLocalRoute(
+ netProto,
+ localAddressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ localAddressEndpoint,
+ )
+
+ if r.IsOutboundBroadcast() {
+ r.Release()
+ return Route{}, false
+ }
+
+ return r, true
+}
+
+// findLocalRouteRLocked returns a local route.
+//
+// A local route is a route to some remote address which the stack owns. That
+// is, a local route is a route where packets never have to leave the stack.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ if len(localAddr) == 0 {
+ localAddr = remoteAddr
+ }
+
+ if localAddressNICID == 0 {
+ for _, localAddressNIC := range s.nics {
+ if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
+ return r, true
+ }
+ }
+
+ return Route{}, false
+ }
+
+ if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
+ return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
+ }
+
+ return Route{}, false
+}
+
// FindRoute creates a route to the given destination address, leaving through
-// the given nic and local address (if provided).
+// the given NIC and local address (if provided).
+//
+// If a NIC is not specified and forwarding is disabled, the returned route
+// will leave through the NIC that the local address is assigned to. If
+// forwarding is enabled and the NIC is unspecified, the route may leave
+// through any interface unless the route is link-local.
+//
+// If no local address is provided, the stack will select a local address. If no
+// remote address is provided, the stack will use a remote address equal to the
+// local address.
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
+ isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
+ needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
+
+ if s.handleLocal && !isMulticast && !isLocalBroadcast {
+ if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ return r, nil
+ }
+ }
+
+ // If the interface is specified and we do not need a route, return a route
+ // through the interface if the interface is valid and enabled.
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.Enabled() {
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
+ return makeRoute(
+ netProto,
+ addressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+				nic, /* outgoingNIC */
+				nic, /* localAddressNIC */
+ addressEndpoint,
+ s.handleLocal,
+ multicastLoop,
+ ), nil
}
}
- } else {
- for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
- continue
+
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
+ }
+
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+
+ // Find a route to the remote with the route table.
+ var chosenRoute tcpip.Route
+ for _, route := range s.routeTable {
+ if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) {
+ continue
+ }
+
+ nic, ok := s.nics[route.NIC]
+ if !ok || !nic.Enabled() {
+ continue
+ }
+
+ if id == 0 || id == route.NIC {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = route.Gateway
+ }
+				r := constructAndValidateRoute(netProto, addressEndpoint, nic /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
+ if r == (Route{}) {
+ panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
+ }
+ return r, nil
}
- if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
- if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if len(remoteAddr) == 0 {
- // If no remote address was provided, then the route
- // provided will refer to the link local address.
- remoteAddr = addressEndpoint.AddressWithPrefix().Address
- }
+ }
+
+ // If the stack has forwarding enabled and we haven't found a valid route to
+ // the remote address yet, keep track of the first valid route. We keep
+ // iterating because we prefer routes that let us use a local address that
+		// is assigned to the outgoing interface. No RFC requires this; it is
+		// simply a choice made to better follow the strong host model, which
+		// netstack follows at the time of writing.
+ if canForward && chosenRoute == (tcpip.Route{}) {
+ chosenRoute = route
+ }
+ }
+
+ if chosenRoute != (tcpip.Route{}) {
+ // At this point we know the stack has forwarding enabled since chosenRoute is
+ // only set when forwarding is enabled.
+ nic, ok := s.nics[chosenRoute.NIC]
+ if !ok {
+ // If the route's NIC was invalid, we should not have chosen the route.
+ panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC))
+ }
+
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = chosenRoute.Gateway
+ }
- r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
- if len(route.Gateway) > 0 {
- if needRoute {
- r.NextHop = route.Gateway
- }
- } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ // Use the specified NIC to get the local address endpoint.
+ if id != 0 {
+ if aNIC, ok := s.nics[id]; ok {
+ if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ return r, nil
}
+ }
+ }
+
+ return Route{}, tcpip.ErrNoRoute
+ }
+
+ if id == 0 {
+ // If an interface is not specified, try to find a NIC that holds the local
+ // address endpoint to construct a route.
+ for _, aNIC := range s.nics {
+ addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto)
+ if addressEndpoint == nil {
+ continue
+ }
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
return r, nil
}
}
}
}
- if !needRoute {
- return Route{}, tcpip.ErrNetworkUnreachable
+ if needRoute {
+ return Route{}, tcpip.ErrNoRoute
}
-
- return Route{}, tcpip.ErrNoRoute
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1457,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
// FindTransportEndpoint finds an endpoint that most closely matches the provided
// id. If no endpoint is found it returns nil.
-func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
- return s.demux.findTransportEndpoint(netProto, transProto, id, r)
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, nicID)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
@@ -1910,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
return tcpip.NewJob(s.clock, l, f)
}
+
+// ParseResult indicates the result of a parsing attempt.
+type ParseResult int
+
+const (
+ // ParsedOK indicates that a packet was successfully parsed.
+ ParsedOK ParseResult = iota
+
+ // UnknownNetworkProtocol indicates that the network protocol is unknown.
+ UnknownNetworkProtocol
+
+ // NetworkLayerParseError indicates that the network packet was not
+ // successfully parsed.
+ NetworkLayerParseError
+
+ // UnknownTransportProtocol indicates that the transport protocol is unknown.
+ UnknownTransportProtocol
+
+ // TransportLayerParseError indicates that the transport packet was not
+ // successfully parsed.
+ TransportLayerParseError
+)
+
+// ParsePacketBuffer parses the provided packet buffer.
+func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult {
+ netProto, ok := s.networkProtocols[protocol]
+ if !ok {
+ return UnknownNetworkProtocol
+ }
+
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
+ if !ok {
+ return NetworkLayerParseError
+ }
+ if !hasTransportHdr {
+ return ParsedOK
+ }
+
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber {
+ return ParsedOK
+ }
+
+ pkt.TransportProtocolNumber = transProtoNum
+ // Parse the transport header if present.
+ state, ok := s.transportProtocols[transProtoNum]
+ if !ok {
+ return UnknownTransportProtocol
+ }
+
+ if !state.proto.Parse(pkt) {
+ return TransportLayerParseError
+ }
+
+ return ParsedOK
+}
+
+// networkProtocolNumbers returns the network protocol numbers the stack is
+// configured with.
+func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
+ protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols))
+ for p := range s.networkProtocols {
+ protos = append(protos, p)
+ }
+ return protos
+}
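
FindRoute's forwarding path above remembers the first matching route-table entry but keeps scanning in case a later entry lets the packet leave through the NIC that owns the local address, per the strong host model note in the comments. The selection preference, boiled down to a self-contained sketch with toy types:

package main

import "fmt"

// entry is a toy route-table entry: the NIC it sends through and whether that
// NIC also owns the local address the caller wants to use.
type entry struct {
	nic           int
	ownsLocalAddr bool
}

// pickRoute prefers an entry whose NIC owns the local address; otherwise it
// falls back to the first entry that matched at all (the forwarding case).
func pickRoute(table []entry) (entry, bool) {
	var fallback *entry
	for i, e := range table {
		if e.ownsLocalAddr {
			return e, true
		}
		if fallback == nil {
			fallback = &table[i]
		}
	}
	if fallback != nil {
		return *fallback, true
	}
	return entry{}, false
}

func main() {
	table := []entry{
		{nic: 1, ownsLocalAddr: false},
		{nic: 2, ownsLocalAddr: true},
	}
	if e, ok := pickRoute(table); ok {
		// Prints NIC 2: the entry whose NIC owns the local address wins.
		fmt.Printf("chose NIC %d\n", e.nic)
	}
}
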
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 4eed4ced4..dedfdd435 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"math"
+ "net"
"sort"
"testing"
"time"
@@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++
+ netHdr := pkt.NetworkHeader().View()
+ f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++
// Handle control packets.
- if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
@@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
+ pkt.NetworkProtocolNumber = fakeNetNumber
hdr[dstAddrOffset] = r.RemoteAddress[0]
hdr[srcAddrOffset] = r.LocalAddress[0]
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- f.HandlePacket(r, pkt)
+ pkt := pkt.Clone()
+ r.PopulatePacketInfo(pkt)
+ f.HandlePacket(pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
if !ok {
return 0, false, false
}
+ pkt.NetworkProtocolNumber = fakeNetNumber
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
@@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) {
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
+// TestExternalSendWithHandleLocal tests that the stack creates a non-local
+// route when spoofing or promiscuous mode is enabled.
+//
+// This test makes sure that packets are transmitted from the stack.
+func TestExternalSendWithHandleLocal(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+
+ localAddr = tcpip.Address("\x01")
+ dstAddr = tcpip.Address("\x03")
+ )
+
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ name string
+ configureStack func(*testing.T, *stack.Stack)
+ }{
+ {
+ name: "Default",
+ configureStack: func(*testing.T, *stack.Stack) {},
+ },
+ {
+ name: "Spoofing",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetSpoofing(nicID, true); err != nil {
+ t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ {
+ name: "Promiscuous",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
+ t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, handleLocal := range []bool{true, false} {
+ t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ HandleLocal: handleLocal,
+ })
+
+ ep := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
+
+ test.configureStack(t, s)
+
+ r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err)
+ }
+ defer r.Release()
+
+ if r.LocalAddress != localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+
+ if n := ep.Drain(); n != 0 {
+ t.Fatalf("got ep.Drain() = %d, want = 0", n)
+ }
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: fakeTransNumber,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.NewView(10).ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, _, _): %s", err)
+ }
+ if n := ep.Drain(); n != 1 {
+ t.Fatalf("got ep.Drain() = %d, want = 1", n)
+ }
+ })
+ }
+ })
+ }
+}
+
func TestSpoofingWithAddress(t *testing.T) {
localAddr := tcpip.Address("\x01")
nonExistentLocalAddr := tcpip.Address("\x02")
@@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
RemoteAddress: ipv4SubnetBcast,
RemoteLinkAddress: header.EthernetBroadcastAddress,
NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
+ Loop: stack.PacketOut | stack.PacketLoop,
},
},
// Broadcast to a locally attached /31 subnet does not populate the
@@ -3756,3 +3862,369 @@ func TestRemoveRoutes(t *testing.T) {
}
}
}
+
+func TestFindRouteWithForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Addr = tcpip.Address("\x01")
+ nic2Addr = tcpip.Address("\x02")
+ remoteAddr = tcpip.Address("\x03")
+ )
+
+ type netCfg struct {
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1Addr tcpip.Address
+ nic2Addr tcpip.Address
+ remoteAddr tcpip.Address
+ }
+
+ fakeNetCfg := netCfg{
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1Addr: nic1Addr,
+ nic2Addr: nic2Addr,
+ remoteAddr: remoteAddr,
+ }
+
+ globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
+ globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
+
+ ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: llAddr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: globalIPv6Addr1,
+ }
+ ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: llAddr1,
+ remoteAddr: llAddr2,
+ }
+ ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ }
+
+ tests := []struct {
+ name string
+
+ netCfg netCfg
+ forwardingEnabled bool
+
+ addrNIC tcpip.NICID
+ localAddr tcpip.Address
+
+ findRouteErr *tcpip.Error
+ dependentOnForwarding bool
+ }{
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+			forwardingEnabled:     true,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+			name:                  "forwarding enabled and link-local local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+			forwardingEnabled:     false,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ }
+
+ if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ }
+
+ if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
+
+ r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if err != test.findRouteErr {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
+ }
+ defer r.Release()
+
+ if test.findRouteErr != nil {
+ return
+ }
+
+ if r.LocalAddress != test.localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr)
+ }
+ if r.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Sending a packet should always go through NIC2 since we only install a
+ // route to test.netCfg.remoteAddr through NIC2.
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := send(r, data); err != nil {
+ t.Fatalf("send(_, _): %s", err)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not sent through ep2")
+ }
+ if pkt.Route.LocalAddress != test.localAddr {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ }
+ if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if !test.forwardingEnabled || !test.dependentOnForwarding {
+ return
+ }
+
+ // Disabling forwarding when the route is dependent on forwarding being
+ // enabled should make the route invalid.
+ if err := s.SetForwarding(test.netCfg.proto, false); err != nil {
+ t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err)
+ }
+ if err := send(r, data); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ if n := ep2.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep2", n)
+ }
+ })
+ }
+}
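
The dependentOnForwarding cases above hinge on a single behaviour: FindRoute only pairs a local address owned by one NIC with a route that egresses another NIC while forwarding is enabled for the protocol, and such a route stops working once forwarding is disabled again. The sketch below isolates that behaviour outside the table-driven test. It is illustrative only: checkForwardingDependentRoute is a hypothetical helper, the package clause is assumed, and it relies on the same SetForwarding/FindRoute/Release calls used by the test above.

    package stack_test

    import (
        "testing"

        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
    )

    // checkForwardingDependentRoute expects localAddr to be assigned to one NIC
    // while the only route to remoteAddr egresses another NIC, as in the
    // "localAddr on different NIC as route" cases above.
    func checkForwardingDependentRoute(t *testing.T, s *stack.Stack, proto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address) {
        t.Helper()

        // With forwarding disabled, the cross-NIC pairing is rejected.
        if err := s.SetForwarding(proto, false); err != nil {
            t.Fatalf("SetForwarding(%d, false): %s", proto, err)
        }
        if _, err := s.FindRoute(0, localAddr, remoteAddr, proto, false /* multicastLoop */); err != tcpip.ErrNoRoute {
            t.Fatalf("got FindRoute(0, %s, %s, %d, false) = %v, want = %s", localAddr, remoteAddr, proto, err, tcpip.ErrNoRoute)
        }

        // With forwarding enabled the lookup succeeds, but the resulting route
        // is only valid for as long as forwarding stays enabled.
        if err := s.SetForwarding(proto, true); err != nil {
            t.Fatalf("SetForwarding(%d, true): %s", proto, err)
        }
        r, err := s.FindRoute(0, localAddr, remoteAddr, proto, false /* multicastLoop */)
        if err != nil {
            t.Fatalf("FindRoute(0, %s, %s, %d, false): %s", localAddr, remoteAddr, proto, err)
        }
        r.Release()
    }
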
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 35e5b1a2e..f183ec6e4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isInboundMulticastOrBroadcast(r) {
- mpep.handlePacketAll(r, id, pkt)
+ if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
+ mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
- queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
return
}
- transEP.HandlePacket(r, id, pkt)
+ transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
type transportDemuxer struct {
+ stack *Stack
+
// protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
queuedProtocols map[protocolIDs]queuedTransportProtocol
@@ -262,11 +264,12 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
+ QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{
+ stack: stack,
protocol: make(map[protocolIDs]*transportEndpoints),
queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
}
@@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
if mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ queuedProtocol.QueuePacket(endpoint, id, pkt.Clone())
} else {
- endpoint.HandlePacket(r, id, pkt.Clone())
+ endpoint.HandlePacket(id, pkt.Clone())
}
}
if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ queuedProtocol.QueuePacket(endpoint, id, pkt)
} else {
- endpoint.HandlePacket(r, id, pkt)
+ endpoint.HandlePacket(id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
if len(destEPs) == 0 {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
return false
}
// handlePacket takes ownership of pkt, so each endpoint needs its own
// copy except for the final one.
for _, ep := range destEPs[:len(destEPs)-1] {
- ep.handlePacket(r, id, pkt.Clone())
+ ep.handlePacket(id, pkt.Clone())
}
- destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ destEPs[len(destEPs)-1].handlePacket(id, pkt)
return true
}
@@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// destination address, then do nothing further and instruct the caller to do
// the same. The network layer handles address validation for specified source
// addresses.
- if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
+ if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
// TCP can only be used to communicate between a single source and a
- // single destination; the addresses must be unicast.
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ // single destination; the addresses must be unicast.
+ d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
return true
}
@@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RUnlock()
if ep == nil {
if protocol == header.UDPProtocolNumber {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
}
return false
}
- ep.handlePacket(r, id, pkt)
+ ep.handlePacket(id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
@@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, pkt)
+ rawEP.HandlePacket(pkt.Clone())
foundRaw = true
}
eps.mu.RUnlock()
@@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
}
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
-func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
@@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[nicID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isInboundMulticastOrBroadcast(r *Route) bool {
- return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
+func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
+ return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}
func isSpecified(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 6b8071467..c457b67a2 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
- if f.acceptQueue != nil {
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- ID: f.ID,
- NetProto: f.NetProto,
- },
- proto: f.proto,
- peerAddr: r.RemoteAddress,
- route: r.Clone(),
- })
+ if f.acceptQueue == nil {
+ return
}
+
+ netHdr := pkt.NetworkHeader().View()
+ route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return
+ }
+ route.ResolveWith(pkt.SourceLinkAddress())
+
+ f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ ID: f.ID,
+ NetProto: f.NetProto,
+ },
+ proto: f.proto,
+ peerAddr: route.RemoteAddress,
+ route: route,
+ })
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
@@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
return stack.UnknownDestinationPacketHandled
}
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 34aab32d0..9b0f3b675 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -10,6 +10,7 @@ go_test(
"link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
+ "route_test.go",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 0dcef7b04..bf7594268 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -33,11 +33,6 @@ import (
func TestForwarding(t *testing.T) {
const (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
- routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
-
host1NICID = 1
routerNICID1 = 2
routerNICID2 = 3
@@ -166,6 +161,38 @@ func TestForwarding(t *testing.T) {
}
},
},
+ {
+ name: "IPv4 host2 server with routerNIC1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ {
+ name: "IPv6 routerNIC2 server with host1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
}
for _, test := range tests {
@@ -179,8 +206,8 @@ func TestForwarding(t *testing.T) {
routerStack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr)
- routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr)
+ host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
+ routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
@@ -321,12 +348,8 @@ func TestForwarding(t *testing.T) {
if err == tcpip.ErrNoLinkAddress {
// Wait for link resolution to complete.
<-ch
-
n, _, err = ep.Write(dataPayload, wOpts)
- } else if err != nil {
- t.Fatalf("ep.Write(_, _): %s", err)
}
-
if err != nil {
t.Fatalf("ep.Write(_, _): %s", err)
}
@@ -343,7 +366,6 @@ func TestForwarding(t *testing.T) {
// Wait for the endpoint to be readable.
<-ch
-
var addr tcpip.FullAddress
v, _, err := ep.Read(&addr)
if err != nil {
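
The hunk earlier in this file trims the write path to a single retry after tcpip.ErrNoLinkAddress: the first Write can fail while ARP/NDP link resolution is still in flight, and the test's channel becomes readable once resolution completes. Below is a minimal sketch of that pattern under the Endpoint.Write signature used in this file; writeWithLinkResolution is a hypothetical helper and ch is simply "a channel signalled when resolution finishes", as in the test.

    package integration_test

    import "gvisor.dev/gvisor/pkg/tcpip"

    // writeWithLinkResolution writes once and, if the remote link address is
    // not yet resolved, waits on ch and retries a single time.
    func writeWithLinkResolution(ep tcpip.Endpoint, ch <-chan struct{}, payload tcpip.SlicePayload, wOpts tcpip.WriteOptions) (int64, *tcpip.Error) {
        n, _, err := ep.Write(payload, wOpts)
        if err == tcpip.ErrNoLinkAddress {
            // Wait for link resolution to complete, then retry the write.
            <-ch
            n, _, err = ep.Write(payload, wOpts)
        }
        return n, err
    }
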
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 6ddcda70c..fe7c1bb3d 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -32,32 +32,36 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-var (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+const (
+ linkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+)
- host1IPv4Addr = tcpip.ProtocolAddress{
+var (
+ ipv4Addr1 = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
PrefixLen: 24,
},
}
- host2IPv4Addr = tcpip.ProtocolAddress{
+ ipv4Addr2 = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
PrefixLen: 8,
},
}
- host1IPv6Addr = tcpip.ProtocolAddress{
+ ipv6Addr1 = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::1").To16()),
PrefixLen: 64,
},
}
- host2IPv6Addr = tcpip.ProtocolAddress{
+ ipv6Addr2 = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::2").To16()),
@@ -89,7 +93,7 @@ func TestPing(t *testing.T) {
name: "IPv4 Ping",
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
- remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
icmpBuf: func(t *testing.T) buffer.View {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
@@ -104,7 +108,7 @@ func TestPing(t *testing.T) {
name: "IPv6 Ping",
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
- remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
icmpBuf: func(t *testing.T) buffer.View {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
@@ -127,7 +131,7 @@ func TestPing(t *testing.T) {
host1Stack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr)
+ host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2)
if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
@@ -143,36 +147,36 @@ func TestPing(t *testing.T) {
t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
tcpip.Route{
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv4Addr1.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
tcpip.Route{
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv6Addr1.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
})
host2Stack.SetRouteTable([]tcpip.Route{
tcpip.Route{
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv4Addr2.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
tcpip.Route{
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv6Addr2.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
})
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index e8caf09ba..421da1add 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -204,7 +204,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
},
})
- wq := waiter.Queue{}
+ var wq waiter.Queue
rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
if err != nil {
t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index f1028823b..cdf0459e3 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -409,7 +409,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- wq := waiter.Queue{}
+ var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
if err != nil {
t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
@@ -447,8 +447,6 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
)
- data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
-
tests := []struct {
name string
broadcastAddr tcpip.Address
@@ -492,16 +490,22 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
},
})
+ type endpointAndWaiter struct {
+ ep tcpip.Endpoint
+ ch chan struct{}
+ }
+ var eps []endpointAndWaiter
// We create endpoints that bind to both the wildcard address and the
// broadcast address to make sure both of these types of "broadcast
// interested" endpoints receive broadcast packets.
- wq := waiter.Queue{}
- var eps []tcpip.Endpoint
for _, bindWildcard := range []bool{false, true} {
// Create multiple endpoints for each type of "broadcast interested"
// endpoint so we can test that all endpoints receive the broadcast
// packet.
for i := 0; i < 2; i++ {
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
@@ -528,7 +532,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
}
}
- eps = append(eps, ep)
+ eps = append(eps, endpointAndWaiter{ep: ep, ch: ch})
}
}
@@ -539,14 +543,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
Port: localPort,
},
}
- if n, _, err := wep.Write(data, writeOpts); err != nil {
+ data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4})
+ if n, _, err := wep.ep.Write(data, writeOpts); err != nil {
t.Fatalf("eps[%d].Write(_, _): %s", i, err)
} else if want := int64(len(data)); n != want {
t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
}
for j, rep := range eps {
- if gotPayload, _, err := rep.Read(nil); err != nil {
+ // Wait for the endpoint to become readable.
+ <-rep.ch
+
+ if gotPayload, _, err := rep.ep.Read(nil); err != nil {
t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
} else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
new file mode 100644
index 000000000..02fc47015
--- /dev/null
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -0,0 +1,388 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TestLocalPing tests pinging a remote that is local to the stack.
+//
+// This tests that a local route is created and packets do not leave the stack.
+func TestLocalPing(t *testing.T) {
+ const (
+ nicID = 1
+ ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
+
+ // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
+ // request/reply packets.
+ icmpDataOffset = 8
+ )
+
+ channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
+ channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
+ channelEP := e.(*channel.Endpoint)
+ if n := channelEP.Drain(); n != 0 {
+ t.Fatalf("got channelEP.Drain() = %d, want = 0", n)
+ }
+ }
+
+ ipv4ICMPBuf := func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv4Echo)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ }
+
+ ipv6ICMPBuf := func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9}
+ hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv6EchoRequest)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ }
+
+ tests := []struct {
+ name string
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+ linkEndpoint func() stack.LinkEndpoint
+ localAddr tcpip.Address
+ icmpBuf func(*testing.T) buffer.View
+ expectedConnectErr *tcpip.Error
+ checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
+ }{
+ {
+ name: "IPv4 loopback",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ localAddr: ipv4Loopback,
+ icmpBuf: ipv4ICMPBuf,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv6 loopback",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ localAddr: header.IPv6Loopback,
+ icmpBuf: ipv6ICMPBuf,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv4 non-loopback",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: channelEP,
+ localAddr: ipv4Addr.Address,
+ icmpBuf: ipv4ICMPBuf,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv6 non-loopback",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: channelEP,
+ localAddr: ipv6Addr.Address,
+ icmpBuf: ipv6ICMPBuf,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv4 loopback without local address",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ icmpBuf: ipv4ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv6 loopback without local address",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ icmpBuf: ipv6ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv4 non-loopback without local address",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: channelEP,
+ icmpBuf: ipv4ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv6 non-loopback without local address",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: channelEP,
+ icmpBuf: ipv6ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ HandleLocal: true,
+ })
+ e := test.linkEndpoint()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if len(test.localAddr) != 0 {
+ if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ }
+ }
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
+
+ connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ if err := ep.Connect(connAddr); err != test.expectedConnectErr {
+ t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
+ }
+
+ if test.expectedConnectErr != nil {
+ return
+ }
+
+ payload := tcpip.SlicePayload(test.icmpBuf(t))
+ var wOpts tcpip.WriteOptions
+ if n, _, err := ep.Write(payload, wOpts); err != nil {
+ t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
+ } else if n != int64(len(payload)) {
+ t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
+ }
+
+ // Wait for the endpoint to become readable.
+ <-ch
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(payload[icmpDataOffset:]), v[icmpDataOffset:]); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != test.localAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr)
+ }
+
+ test.checkLinkEndpoint(t, e)
+ })
+ }
+}
+
+// TestLocalUDP tests sending UDP packets between two endpoints that are local
+// to the stack.
+//
+// This tests that packets never leave the stack and checks the addresses
+// used when sending a packet.
+func TestLocalUDP(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ tests := []struct {
+ name string
+ canBePrimaryAddr tcpip.ProtocolAddress
+ firstPrimaryAddr tcpip.ProtocolAddress
+ }{
+ {
+ name: "IPv4",
+ canBePrimaryAddr: ipv4Addr1,
+ firstPrimaryAddr: ipv4Addr2,
+ },
+ {
+ name: "IPv6",
+ canBePrimaryAddr: ipv6Addr1,
+ firstPrimaryAddr: ipv6Addr2,
+ },
+ }
+
+ subTests := []struct {
+ name string
+ addAddress bool
+ expectedWriteErr *tcpip.Error
+ }{
+ {
+ name: "Unassigned local address",
+ addAddress: false,
+ expectedWriteErr: tcpip.ErrNoRoute,
+ },
+ {
+ name: "Assigned local address",
+ addAddress: true,
+ expectedWriteErr: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ HandleLocal: true,
+ }
+
+ s := stack.New(stackOpts)
+ ep := channel.New(1, header.IPv6MinimumMTU, "")
+
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if subTest.addAddress {
+ if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint, err)
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ }
+
+ var serverWQ waiter.Queue
+ serverWE, serverCH := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverWE, waiter.EventIn)
+ server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
+ }
+ defer server.Close()
+
+ bindAddr := tcpip.FullAddress{Port: 80}
+ if err := server.Bind(bindAddr); err != nil {
+ t.Fatalf("server.Bind(%#v): %s", bindAddr, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventIn)
+ client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
+ }
+ defer client.Close()
+
+ serverAddr := tcpip.FullAddress{
+ Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
+ Port: 80,
+ }
+
+ clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ {
+ wOpts := tcpip.WriteOptions{
+ To: &serverAddr,
+ }
+ if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
+ t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
+ } else if subTest.expectedWriteErr != nil {
+ // Nothing else to test if we expected not to be able to send the
+ // UDP packet.
+ return
+ } else if n != int64(len(clientPayload)) {
+ t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload))
+ }
+ }
+
+ // Wait for the server endpoint to become readable.
+ <-serverCH
+
+ var clientAddr tcpip.FullAddress
+ if v, _, err := server.Read(&clientAddr); err != nil {
+ t.Fatalf("server.Read(_): %s", err)
+ } else {
+ if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" {
+ t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
+ }
+ if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address {
+ t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ {
+ wOpts := tcpip.WriteOptions{
+ To: &clientAddr,
+ }
+ if n, _, err := server.Write(serverPayload, wOpts); err != nil {
+ t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
+ } else if n != int64(len(serverPayload)) {
+ t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload))
+ }
+ }
+
+ // Wait for the client endpoint to become readable.
+ <-clientCH
+
+ var gotServerAddr tcpip.FullAddress
+ if v, _, err := client.Read(&gotServerAddr); err != nil {
+ t.Fatalf("client.Read(_): %s", err)
+ } else {
+ if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" {
+ t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
+ }
+ if gotServerAddr.Addr != serverAddr.Addr {
+ t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+ })
+ }
+ })
+ }
+}
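
Both tests in this new file lean on the same readiness idiom: register a channel entry for waiter.EventIn before triggering traffic, block on the channel, then Read. The sketch below isolates just that idiom; readWhenReady is a hypothetical helper (not part of route_test.go) and the bind/connect/write steps are elided since they are shown in full above.

    package integration_test

    import (
        "testing"

        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
        "gvisor.dev/gvisor/pkg/waiter"
    )

    // readWhenReady creates an endpoint whose readability is signalled on ch,
    // then blocks on ch before reading, mirroring the <-ch, <-serverCH and
    // <-clientCH waits above.
    func readWhenReady(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) {
        t.Helper()

        var wq waiter.Queue
        we, ch := waiter.NewChannelEntry(nil)
        wq.EventRegister(&we, waiter.EventIn)
        defer wq.EventUnregister(&we)

        ep, err := s.NewEndpoint(transProto, netProto, &wq)
        if err != nil {
            t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err)
        }
        defer ep.Close()

        // ... bind/connect and the peer's write elided ...

        // Block until the endpoint reports EventIn, then drain it.
        <-ch
        var addr tcpip.FullAddress
        if _, _, err := ep.Read(&addr); err != nil {
            t.Fatalf("ep.Read(_): %s", err)
        }
    }
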
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index a17234946..763cd8f84 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -755,7 +755,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
@@ -800,7 +800,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Push new packet into receive list and increment the buffer size.
packet := &icmpPacket{
senderAddress: tcpip.FullAddress{
- NIC: r.NICID(),
+ NIC: pkt.NICID,
Addr: id.RemoteAddress,
},
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 87d510f96..3820e5dc7 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -101,7 +101,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
return stack.UnknownDestinationPacketHandled
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 79f688129..7b6a87ba9 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -646,7 +646,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full or if this is an unassociated
@@ -671,14 +671,16 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
return
}
+ remoteAddr := pkt.Network().SourceAddress()
+
if e.bound {
// If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != route.NICID() {
+ if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
e.rcvMu.Unlock()
return
}
// If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != route.RemoteAddress {
+ if e.BindAddr != "" && e.BindAddr != remoteAddr {
e.rcvMu.Unlock()
return
}
@@ -686,7 +688,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
// If connected, only accept packets from the remote address we
// connected to.
- if e.connected && e.route.RemoteAddress != route.RemoteAddress {
+ if e.connected && e.route.RemoteAddress != remoteAddr {
e.rcvMu.Unlock()
return
}
@@ -696,8 +698,8 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
// Push new packet into receive list and increment the buffer size.
packet := &rawPacket{
senderAddr: tcpip.FullAddress{
- NIC: route.NICID(),
- Addr: route.RemoteAddress,
+ NIC: pkt.NICID,
+ Addr: remoteAddr,
},
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6b3238d6b..47982ca41 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
- netProto = s.route.NetProto
+ netProto = s.netProto
}
+
+ route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return nil, err
+ }
+ route.ResolveWith(s.remoteLinkAddr)
+
n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6Only
n.ID = s.id
- n.boundNICID = s.route.NICID()
- n.route = s.route.Clone()
- n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+ n.boundNICID = s.nicID
+ n.route = route
+ n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto}
n.rcvBufSize = int(l.rcvWnd)
n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
@@ -225,7 +232,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- return n
+ return n, nil
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
@@ -236,7 +243,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ if err != nil {
+ return nil, err
+ }
// Lock the endpoint before registering to ensure that no out of
// band changes are possible due to incoming packets etc till
@@ -467,7 +477,7 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
-func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
e.rcvListMu.Unlock()
@@ -477,8 +487,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
// must be sent in response to a SYN-ACK while in the listen
// state to prevent completing a handshake from an old SYN.
- replyWithReset(s, e.sendTOS, e.ttl)
- return
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
switch {
@@ -492,13 +501,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
- return
+ return nil
}
ctx.synRcvdCount.dec()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
} else {
// If cookies are in use but the endpoint accept queue
// is full then drop the syn.
@@ -506,10 +515,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ route.ResolveWith(s.remoteLinkAddr)
+
// Send SYN without window scaling because we currently
// don't encode this information in the cookie.
//
@@ -523,9 +539,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
TS: opts.TS,
TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: calculateAdvertisedMSS(e.userMSS, s.route),
+ MSS: calculateAdvertisedMSS(e.userMSS, route),
}
- e.sendSynTCP(&s.route, tcpFields{
+ fields := tcpFields{
id: s.id,
ttl: e.ttl,
tos: e.sendTOS,
@@ -533,8 +549,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
seq: cookie,
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
- }, synOpts)
+ }
+ if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ return err
+ }
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ return nil
}
case (s.flags & header.TCPFlagAck) != 0:
@@ -547,7 +567,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
if !ctx.synRcvdCount.synCookiesInUse() {
@@ -566,8 +586,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// The only time we should reach here when a connection
// was opened and closed really quickly and a delayed
// ACK was received from the sender.
- replyWithReset(s, e.sendTOS, e.ttl)
- return
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
iss := s.ackNumber - 1
@@ -587,7 +606,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
@@ -608,7 +627,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ if err != nil {
+ return err
+ }
n.mu.Lock()
@@ -622,7 +644,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- return
+ return nil
}
// Register new endpoint so that packets are routed to it.
@@ -632,7 +654,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- return
+ return err
}
n.isRegistered = true
@@ -670,12 +692,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
+ return nil
+
+ default:
+ return nil
}
}
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
-func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.v6only
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
@@ -714,12 +740,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
- return nil
+ return
}
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
s := e.segmentQueue.dequeue()
- e.handleListenSegment(ctx, s)
+ // TODO(gvisor.dev/issue/4690): Better handle errors instead of
+ // silently dropping.
+ _ = e.handleListenSegment(ctx, s)
s.decRef()
}
close(e.drainDone)
@@ -738,7 +766,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
break
}
- e.handleListenSegment(ctx, s)
+ // TODO(gvisor.dev/issue/4690): Better handle errors instead of
+ // silently dropping.
+ _ = e.handleListenSegment(ctx, s)
s.decRef()
}
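
Since segments no longer carry a Route, the listen path above constructs one on demand from the segment's metadata and seeds it with the link address already learned from the incoming packet, avoiding a fresh link resolution. Below is a compact sketch of that pattern using only the FindRoute/ResolveWith calls visible in this hunk; routeForSegment is a hypothetical helper and the segment fields (nicID, dstAddr, srcAddr, netProto, remoteLinkAddr) are passed as plain arguments.

    package tcp

    import (
        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
    )

    // routeForSegment builds an egress route back to the sender of an incoming
    // segment: local/remote are the segment's destination/source, and the
    // known remote link address is installed up front. The caller must call
    // Release on the returned route when done with it.
    func routeForSegment(stk *stack.Stack, nicID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) (stack.Route, *tcpip.Error) {
        route, err := stk.FindRoute(nicID, localAddr, remoteAddr, netProto, false /* multicastLoop */)
        if err != nil {
            return stack.Route{}, err
        }
        route.ResolveWith(remoteLinkAddr)
        return route, nil
    }
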
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 0aaef495d..2facbebec 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -293,9 +293,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
MSS: amss,
}
if ttl == 0 {
- ttl = s.route.DefaultTTL()
+ ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&s.route, tcpFields{
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -356,7 +356,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&s.route, tcpFields{
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -496,7 +496,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
// Wait for notification.
- index, _ = s.Fetch(true)
+ h.ep.mu.Unlock()
+ index, _ = s.Fetch(true /* block */)
+ h.ep.mu.Lock()
}
}
@@ -566,8 +568,10 @@ func (h *handshake) execute() *tcpip.Error {
}, synOpts)
for h.state != handshakeCompleted {
+ // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
+ // throughout handshake processing).
h.ep.mu.Unlock()
- index, _ := s.Fetch(true)
+ index, _ := s.Fetch(true /* block */)
h.ep.mu.Lock()
switch index {
@@ -767,7 +771,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
// TCP header, then the kernel calculate a checksum of the
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
- } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ } else if r.RequiresTXTransportChecksum() {
xsum = header.ChecksumVV(pkt.Data, xsum)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
@@ -1040,13 +1044,13 @@ func (e *endpoint) transitionToStateCloseLocked() {
// only when the endpoint is in StateClose and we want to deliver the segment
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
- ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
- ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
if ep == nil {
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */)
s.decRef()
return
}
@@ -1366,7 +1370,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
drained := e.drainDone != nil
if drained {
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
// Set up the functions that will be called when the main protocol loop
@@ -1535,7 +1541,7 @@ loop:
}
e.mu.Unlock()
- v, _ := s.Fetch(true)
+ v, _ := s.Fetch(true /* block */)
e.mu.Lock()
// We need to double check here because the notification may be
@@ -1620,7 +1626,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
}
for _, netProto := range netProtos {
- if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil {
tcpEP := listenEP.(*endpoint)
if EndpointState(tcpEP.State()) == StateListen {
reuseTW = func() {
@@ -1683,7 +1689,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
for {
e.mu.Unlock()
- v, _ := s.Fetch(true)
+ v, _ := s.Fetch(true /* block */)
e.mu.Lock()
switch v {
case newSegment:
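
The handshake and main-loop changes above all apply the same locking discipline: release the endpoint mutex before a blocking Fetch (or channel receive) and reacquire it afterwards, so incoming packets and other goroutines are not stalled while this goroutine sleeps. The snippet below is a generic illustration of that discipline only; it is not connect.go code, and the mutex and wake channel are stand-ins.

    package tcp

    import "sync"

    // waitUnlocked blocks on wake without holding mu, mirroring the
    // mu.Unlock() / s.Fetch(true) / mu.Lock() sequences above.
    func waitUnlocked(mu *sync.Mutex, wake <-chan struct{}) {
        mu.Unlock()
        <-wake // block without holding the endpoint lock
        mu.Lock()
    }
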
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 98aecab9e..21162f01a 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -172,10 +172,11 @@ func (d *dispatcher) wait() {
d.wg.Wait()
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
- s := newSegment(r, id, pkt)
- if !s.parse() {
+
+ s := newIncomingSegment(id, pkt)
+ if !s.parse(pkt.RXTransportChecksumValidated) {
ep.stack.Stats().MalformedRcvdPackets.Increment()
ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 560b4904c..a6f25896b 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -236,6 +236,25 @@ func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
testV6Connect(t, c)
}
+func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) {
+ c := context.NewWithOpts(t, context.Options{
+ EnableV6: true,
+ MTU: defaultMTU,
+ })
+ defer c.Cleanup()
+
+ // Create a v6 endpoint but don't set the v6-only TCP option.
+ c.CreateV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c)
+}
+
func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index c826942e9..258f9f1bb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -721,9 +721,9 @@ func (e *endpoint) LockUser() {
for {
// First try to take the lock; if the sock is already locked, check if it's
// owned by another user goroutine. If not then we spin, otherwise
- // we just goto sleep on the Lock() and wait.
+ // we just go to sleep on the Lock() and wait.
if !e.mu.TryLock() {
- // If socket is owned by the user then just goto sleep
+ // If socket is owned by the user then just go to sleep
// as the lock could be held for a reasonably long time.
if atomic.LoadUint32(&e.ownedByUser) == 1 {
e.mu.Lock()
@@ -1425,7 +1425,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
// Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
+ s := newOutgoingSegment(e.ID, v)
e.sndBufUsed += len(v)
e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
@@ -2316,7 +2316,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// done yet) or the reservation was freed between the check above and
// the FindTransportEndpoint below. But rather than retry the same port
// we just skip it and move on.
- transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID())
if transEP == nil {
// ReservePort failed but there is no endpoint registered with the
// demuxer, which indicates there is at least some endpoint that has
@@ -2385,7 +2385,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.ID
- s.route = r.Clone()
e.sndWaker.Assert()
}
}
@@ -2451,7 +2450,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Queue fin segment.
- s := newSegmentFromView(&e.route, e.ID, nil)
+ s := newOutgoingSegment(e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
// Mark endpoint as closed.
@@ -2633,14 +2632,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv6ProtocolNumber,
- header.IPv4ProtocolNumber,
+
+ // Expand netProtos to include v4 and v6 under dual-stack if the caller is
+ // binding to a wildcard (empty) address, and this is an IPv6 endpoint with
+ // v6only set to false.
+ if netProto == header.IPv6ProtocolNumber {
+ stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber)
+ alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4
+ if alsoBindToV4 {
+ netProtos = append(netProtos, header.IPv4ProtocolNumber)
}
}
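The bind expansion now also checks that the stack actually registered IPv4 before adding it to netProtos, so a wildcard bind on a v6-only stack stays v6-only. A hedged sketch of an endpoint exercising that path, using the standard stack.New/NewEndpoint/Bind APIs (the setup mirrors the v6-only test context above):

package main

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
	"gvisor.dev/gvisor/pkg/waiter"
)

func main() {
	// A v6-only stack: IPv4 is never registered, so bindLocked will not
	// expand a wildcard bind to include IPv4.
	s := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv6.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
	})

	var wq waiter.Queue
	ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
	if err != nil {
		panic(err)
	}
	defer ep.Close()

	// Wildcard bind: with this change it binds as IPv6-only because
	// CheckNetworkProtocol(ipv4) reports false on this stack.
	if err := ep.Bind(tcpip.FullAddress{}); err != nil {
		panic(err)
	}
}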
@@ -2721,7 +2722,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
}
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher, which can then either deliver them using the
// worker goroutine or directly invoke the TCP processing inline
@@ -3080,9 +3081,9 @@ func (e *endpoint) initHardwareGSO() {
}
func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.route.HasHardwareGSOCapability() {
e.initHardwareGSO()
- } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ } else if e.route.HasSoftwareGSOCapability() {
e.gso = &stack.GSO{
MaxSize: e.route.GSOMaxSize(),
Type: stack.GSOSW,
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b25431467..2bcc5e1c2 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() {
switch {
case epState == StateInitial || epState == StateBound:
case epState.connected() || epState.handshake():
- if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
- if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
+ if !e.route.HasSaveRestoreCapability() {
+ if !e.route.HasDisconncetOkCapability() {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
}
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 070b634b4..0664789da 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -30,6 +30,8 @@ import (
// The canonical way of using it is to pass the Forwarder.HandlePacket function
// to stack.SetTransportProtocolHandler.
type Forwarder struct {
+ stack *stack.Stack
+
maxInFlight int
handler func(*ForwarderRequest)
@@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
rcvWnd = DefaultReceiveBufferSize
}
return &Forwarder{
+ stack: s,
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
@@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
- s := newSegment(r, id, pkt)
+func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newIncomingSegment(id, pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
- if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn {
return false
}
@@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
delete(r.forwarder.inFlight, r.segment.id)
r.forwarder.mu.Unlock()
- // If the caller requested, send a reset.
if sendReset {
- replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
+ replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */)
}
// Release all resources.
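With the Forwarder carrying its own *stack.Stack for sending resets, construction is unchanged and only HandlePacket lost its route parameter. A minimal usage sketch; the window size, in-flight limit, and handler body are illustrative only:

package main

import (
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
	"gvisor.dev/gvisor/pkg/waiter"
)

// installTCPForwarder hands every unclaimed TCP SYN on s to a handler that
// either completes the connection or replies with a RST via Complete(true).
func installTCPForwarder(s *stack.Stack) {
	fwd := tcp.NewForwarder(s, 30000 /* rcvWnd */, 10 /* maxInFlight */, func(r *tcp.ForwarderRequest) {
		var wq waiter.Queue
		ep, err := r.CreateEndpoint(&wq)
		if err != nil {
			// Could not create an endpoint; send a reset using the
			// forwarder's stack (the route is rebuilt from the segment).
			r.Complete(true /* sendReset */)
			return
		}
		r.Complete(false /* sendReset */)
		ep.Close() // a real handler would proxy data here
	})
	s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
}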
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 5bce73605..2329aca4b 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- p.dispatcher.queuePacket(r, ep, id, pkt)
+func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ p.dispatcher.queuePacket(ep, id, pkt)
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
@@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
- s := newSegment(r, id, pkt)
+func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ s := newIncomingSegment(id, pkt)
defer s.decRef()
- if !s.parse() || !s.csumValid {
+ if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid {
return stack.UnknownDestinationPacketMalformed
}
if !s.flagIsSet(header.TCPFlagRst) {
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ replyWithReset(p.stack, s, stack.DefaultTOS, 0 /* ttl */)
}
return stack.UnknownDestinationPacketHandled
}
// replyWithReset replies to the given segment with a reset segment.
-func replyWithReset(s *segment, tos, ttl uint8) {
+//
+// If the passed TTL is 0, then the route's default TTL will be used.
+func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error {
+ route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ route.ResolveWith(s.remoteLinkAddr)
+
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
@@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) {
flags |= header.TCPFlagAck
ack = s.sequenceNumber.Add(s.logicalLen())
}
- sendTCP(&s.route, tcpFields{
+
+ if ttl == 0 {
+ ttl = route.DefaultTTL()
+ }
+
+ return sendTCP(&route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
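The reworked replyWithReset shows the general pattern for replying without a per-segment cached route: find a route from the addresses saved on the segment, resolve it with the remote link address that came with the packet, and fall back to the route's default TTL when the caller passes 0. A hedged helper-style sketch of that lookup (parameter names are illustrative; the caller must still Release the returned route):

package tcpsketch

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// replyRoute derives a route for a reply from the addresses recorded on the
// incoming segment rather than from a cached stack.Route.
func replyRoute(s *stack.Stack, nic tcpip.NICID, local, remote tcpip.Address,
	netProto tcpip.NetworkProtocolNumber, remoteLink tcpip.LinkAddress, ttl uint8) (stack.Route, uint8, *tcpip.Error) {
	r, err := s.FindRoute(nic, local, remote, netProto, false /* multicastLoop */)
	if err != nil {
		return stack.Route{}, 0, err
	}
	// The sender's link address is already known from the packet, so skip
	// neighbor resolution.
	r.ResolveWith(remoteLink)
	if ttl == 0 {
		ttl = r.DefaultTTL()
	}
	return r, ttl, nil
}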
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 1f9c5cf50..2091989cc 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -45,9 +46,18 @@ type segment struct {
ep *endpoint
qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
- route stack.Route `state:"manual"`
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- hdr header.TCP
+
+ // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of
+ // individual members for link/network packet info.
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ netProto tcpip.NetworkProtocolNumber
+ nicID tcpip.NICID
+ remoteLinkAddr tcpip.LinkAddress
+
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -76,11 +86,16 @@ type segment struct {
acked bool
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+ netHdr := pkt.Network()
s := &segment{
- refCnt: 1,
- id: id,
- route: r.Clone(),
+ refCnt: 1,
+ id: id,
+ srcAddr: netHdr.SourceAddress(),
+ dstAddr: netHdr.DestinationAddress(),
+ netProto: pkt.NetworkProtocolNumber,
+ nicID: pkt.NICID,
+ remoteLinkAddr: pkt.SourceLinkAddress(),
}
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
@@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB
return s
}
-func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s := &segment{
refCnt: 1,
id: id,
- route: r.Clone(),
}
s.rcvdTime = time.Now()
if len(v) != 0 {
@@ -110,7 +124,9 @@ func (s *segment) clone() *segment {
ackNumber: s.ackNumber,
flags: s.flags,
window: s.window,
- route: s.route.Clone(),
+ netProto: s.netProto,
+ nicID: s.nicID,
+ remoteLinkAddr: s.remoteLinkAddr,
viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
@@ -160,7 +176,6 @@ func (s *segment) decRef() {
panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
}
}
- s.route.Release()
}
}
@@ -198,10 +213,10 @@ func (s *segment) segMemSize() int {
//
// Returns boolean indicating if the parsing was successful.
//
-// If checksum verification is not offloaded then parse also verifies the
+// If checksum verification may not be skipped, parse also verifies the
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
-func (s *segment) parse() bool {
+func (s *segment) parse(skipChecksumValidation bool) bool {
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -220,16 +235,14 @@ func (s *segment) parse() bool {
s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
- // Query the link capabilities to decide if checksum validation is
- // required.
verifyChecksum := true
- if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
+ if skipChecksumValidation {
s.csumValid = true
verifyChecksum = false
}
if verifyChecksum {
s.csum = s.hdr.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr)))
xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
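parse now feeds header.PseudoHeaderChecksum from the addresses stored on the segment instead of asking the route. The same verification, written as a hedged stand-alone check with the payload reduced to a single byte slice:

package tcpsketch

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// tcpChecksumOK sketches the verification parse performs when the NIC has not
// already validated the checksum: a pseudo-header checksum over
// src/dst/proto/length, folded with the TCP header and payload, must come out
// to 0xffff.
func tcpChecksumOK(src, dst tcpip.Address, hdr header.TCP, payload []byte) bool {
	length := uint16(len(hdr) + len(payload))
	xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, src, dst, length)
	xsum = hdr.CalculateChecksum(xsum)
	xsum = header.Checksum(payload, xsum)
	return xsum == 0xffff
}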
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 6fa8d63cd..ab5fa4fb7 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -1285,6 +1285,10 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
func (s *sender) walkSACK(rcvdSeg *segment) {
+ if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 {
+ return
+ }
+
// Sort the SACK blocks. The first block is the most recent unacked
// block. The following blocks can be in arbitrary order.
sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 79646fefe..f791f8f13 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -112,6 +112,18 @@ type Headers struct {
TCPOpts []byte
}
+// Options contains options for creating a new test context.
+type Options struct {
+ // EnableV4 indicates whether IPv4 should be enabled.
+ EnableV4 bool
+
+ // EnableV6 indicates whether IPv6 should be enabled.
+ EnableV6 bool
+
+ // MTU indicates the maximum transmission unit on the link layer.
+ MTU uint32
+}
+
// Context provides an initialized Network stack and a link layer endpoint
// for use in TCP tests.
type Context struct {
@@ -154,10 +166,30 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ return NewWithOpts(t, Options{
+ EnableV4: true,
+ EnableV6: true,
+ MTU: mtu,
})
+}
+
+// NewWithOpts allocates and initializes a test context containing a new
+// stack and a link-layer endpoint with specific options.
+func NewWithOpts(t *testing.T, opts Options) *Context {
+ if opts.MTU == 0 {
+ panic("MTU must be greater than 0")
+ }
+
+ stackOpts := stack.Options{
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ }
+ if opts.EnableV4 {
+ stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
+ }
+ if opts.EnableV6 {
+ stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol)
+ }
+ s := stack.New(stackOpts)
const sendBufferSize = 1 << 20 // 1 MiB
const recvBufferSize = 1 << 20 // 1 MiB
@@ -182,50 +214,55 @@ func New(t *testing.T, mtu uint32) *Context {
// Some of the congestion control tests send up to 640 packets, so we
// set the channel size to 1000.
- ep := channel.New(1000, mtu, "")
+ ep := channel.New(1000, opts.MTU, "")
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- opts := stack.NICOptions{Name: "nic1"}
- if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+ nicOpts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
+ wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, ""))
if testing.Verbose() {
- wep2 = sniffer.New(channel.New(1000, mtu, ""))
+ wep2 = sniffer.New(channel.New(1000, opts.MTU, ""))
}
opts2 := stack.NICOptions{Name: "nic2"}
if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- v4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: StackAddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
- }
-
- v6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: StackV6AddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
- }
+ var routeTable []tcpip.Route
- s.SetRouteTable([]tcpip.Route{
- {
+ if opts.EnableV4 {
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ }
+ routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
- },
- {
+ })
+ }
+
+ if opts.EnableV6 {
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ }
+ routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
NIC: 1,
- },
- })
+ })
+ }
+
+ s.SetRouteTable(routeTable)
return &Context{
t: t,
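With NewWithOpts in place, New(t, mtu) is just the dual-stack special case. A hedged counterpart to the v6-only test added above, this time building a v4-only context (Create(-1) is assumed to be the existing helper that makes a plain IPv4 endpoint):

func TestStackV4OnlyContext(t *testing.T) {
	c := context.NewWithOpts(t, context.Options{
		EnableV4: true,
		MTU:      defaultMTU,
	})
	defer c.Cleanup()

	// Only the IPv4 address and route are installed on NIC 1; an IPv6
	// endpoint cannot be created on this stack.
	c.Create(-1 /* epRcvBuf */)
}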
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index cdb5127ab..9bcb918bb 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -487,6 +487,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
+ if to.Port == 0 {
+ // Port 0 is an invalid port to send to.
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
@@ -1012,7 +1017,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ if r.RequiresTXTransportChecksum() &&
(!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
@@ -1382,10 +1387,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
// omitted the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
-func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool {
- if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
- (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
+ if !pkt.RXTransportChecksumValidated &&
+ (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) {
+ netHdr := pkt.Network()
+ xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length())
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
}
@@ -1396,7 +1402,7 @@ func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) boo
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
@@ -1406,7 +1412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
- if !verifyChecksum(r, hdr, pkt) {
+ if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
e.stats.ReceiveErrors.ChecksumErrors.Increment()
@@ -1437,7 +1443,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Push new packet into receive list and increment the buffer size.
packet := &udpPacket{
senderAddress: tcpip.FullAddress{
- NIC: r.NICID(),
+ NIC: pkt.NICID,
Addr: id.RemoteAddress,
Port: header.UDP(hdr).SourcePort(),
},
@@ -1447,7 +1453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvBufSize += pkt.Data.Size()
// Save any useful information from the network header to the packet.
- switch r.NetProto {
+ switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS()
case header.IPv6ProtocolNumber:
@@ -1457,9 +1463,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
// address. packetInfo.LocalAddr should hold a unicast address that can be
// used to respond to the incoming packet.
- packet.packetInfo.LocalAddr = r.LocalAddress
- packet.packetInfo.DestinationAddr = r.LocalAddress
- packet.packetInfo.NIC = r.NICID()
+ localAddr := pkt.Network().DestinationAddress()
+ packet.packetInfo.LocalAddr = localAddr
+ packet.packetInfo.DestinationAddr = localAddr
+ packet.packetInfo.NIC = pkt.NICID
packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 3ae6cc221..14e4648cd 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,10 +43,9 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
- route: r,
id: id,
pkt: pkt,
})
@@ -59,7 +58,6 @@ func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, p
// it via CreateEndpoint.
type ForwarderRequest struct {
stack *stack.Stack
- route *stack.Route
id stack.TransportEndpointID
pkt *stack.PacketBuffer
}
@@ -72,17 +70,25 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ netHdr := r.pkt.Network()
+ route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return nil, err
+ }
+ route.ResolveWith(r.pkt.SourceLinkAddress())
+
+ ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
ep.Close()
+ route.Release()
return nil, err
}
ep.ID = r.id
- ep.route = r.route.Clone()
+ ep.route = route
ep.dstPort = r.id.RemotePort
- ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto}
- ep.RegisterNICID = r.route.NICID()
+ ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
+ ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
ep.state = StateConnected
@@ -91,7 +97,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.rcvReady = true
ep.rcvMu.Unlock()
- ep.HandlePacket(r.route, r.id, r.pkt)
+ ep.HandlePacket(r.id, r.pkt)
return ep, nil
}
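CreateEndpoint now rebuilds the route from the packet buffer (NIC, network protocol, and addresses) instead of borrowing a route from the handler, but the forwarder's public usage is unchanged. A minimal, hedged sketch of that usage:

package main

import (
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
	"gvisor.dev/gvisor/pkg/waiter"
)

// installUDPForwarder claims every UDP packet that does not match an existing
// endpoint and turns the session into a connected endpoint.
func installUDPForwarder(s *stack.Stack) {
	fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
		var wq waiter.Queue
		ep, err := r.CreateEndpoint(&wq)
		if err != nil {
			// The route lookup inside CreateEndpoint can fail; the
			// packet is simply dropped in that case.
			return
		}
		defer ep.Close() // a real handler would read/write the session here
	})
	s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
}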
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index da5b1deb2..91420edd3 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -78,15 +78,15 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets that are targeted at this
// protocol but don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
- r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ p.stack.Stats().UDP.MalformedPacketsReceived.Increment()
return stack.UnknownDestinationPacketMalformed
}
- if !verifyChecksum(r, hdr, pkt) {
- r.Stack().Stats().UDP.ChecksumErrors.Increment()
+ if !verifyChecksum(hdr, pkt) {
+ p.stack.Stats().UDP.ChecksumErrors.Increment()
return stack.UnknownDestinationPacketMalformed
}
diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go
index 5c4b9e8e9..a38ffc19d 100644
--- a/pkg/unet/unet_test.go
+++ b/pkg/unet/unet_test.go
@@ -53,40 +53,40 @@ func randomFilename() (string, error) {
func TestConnectFailure(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
if _, err := Connect(name, false); err == nil {
- t.Fatalf("connect was successful, expected err")
+ t.Fatalf("Connect was successful, expected err")
}
}
func TestBindFailure(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
defer ss.Close()
if _, err = BindAndListen(name, false); err == nil {
- t.Fatalf("second bind succeeded, expected non-nil err")
+ t.Fatalf("Second bind succeeded, expected non-nil err")
}
}
func TestMultipleAccept(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
defer ss.Close()
@@ -99,7 +99,8 @@ func TestMultipleAccept(t *testing.T) {
defer wg.Done()
s, err := Connect(name, false)
if err != nil {
- t.Fatalf("connect failed, got err %v expected nil", err)
+ t.Errorf("Connect failed, got err %v expected nil", err)
+ return
}
s.Close()
}()
@@ -109,7 +110,7 @@ func TestMultipleAccept(t *testing.T) {
for i := 0; i < backlog; i++ {
s, err := ss.Accept()
if err != nil {
- t.Errorf("accept failed, got err %v expected nil", err)
+ t.Errorf("Accept failed, got err %v expected nil", err)
continue
}
s.Close()
@@ -119,35 +120,35 @@ func TestMultipleAccept(t *testing.T) {
func TestServerClose(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
// Make sure the first close succeeds.
if err := ss.Close(); err != nil {
- t.Fatalf("first close failed, got err %v expected nil", err)
+ t.Fatalf("First close failed, got err %v expected nil", err)
}
// The second one should fail.
if err := ss.Close(); err == nil {
- t.Fatalf("second close succeeded, expected non-nil err")
+ t.Fatalf("Second close succeeded, expected non-nil err")
}
}
func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
// Bind a server.
ss, err := BindAndListen(name, packet)
if err != nil {
- t.Fatalf("error binding, got %v expected nil", err)
+ t.Fatalf("Error binding, got %v expected nil", err)
}
defer ss.Close()
@@ -165,7 +166,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
// Connect the client.
client, err := Connect(name, packet)
if err != nil {
- t.Fatalf("error connecting, got %v expected nil", err)
+ t.Fatalf("Error connecting, got %v expected nil", err)
}
// Grab the server handle.
@@ -173,7 +174,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
case server := <-acceptSocket:
return server, client
case err := <-acceptErr:
- t.Fatalf("accept error: %v", err)
+ t.Fatalf("Accept error: %v", err)
}
panic("unreachable")
}
@@ -186,17 +187,17 @@ func TestSendRecv(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'b'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -211,17 +212,17 @@ func TestSymmetric(t *testing.T) {
// Write on the server.
w := server.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the client.
b := [][]byte{{'b'}}
r := client.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -233,13 +234,13 @@ func TestPacket(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Write on the client again.
w = client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the server.
@@ -249,19 +250,19 @@ func TestPacket(t *testing.T) {
b := [][]byte{{'b', 'b'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Do it again.
r = server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -271,12 +272,12 @@ func TestClose(t *testing.T) {
// Make sure the first close succeeds.
if err := client.Close(); err != nil {
- t.Fatalf("first close failed, got err %v expected nil", err)
+ t.Fatalf("First close failed, got err %v expected nil", err)
}
// The second one should fail.
if err := client.Close(); err == nil {
- t.Fatalf("second close succeeded, expected non-nil err")
+ t.Fatalf("Second close succeeded, expected non-nil err")
}
}
@@ -294,17 +295,17 @@ func TestNonBlockingSend(t *testing.T) {
// We're good. That's what we wanted.
blockCount++
} else {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1000 err=nil", n, err)
}
}
}
if blockCount == 1000 {
// Shouldn't have _always_ blocked.
- t.Fatalf("socket always blocked!")
+ t.Fatalf("Socket always blocked!")
} else if blockCount == 0 {
// Should have started blocking eventually.
- t.Fatalf("socket never blocked!")
+ t.Fatalf("Socket never blocked!")
}
}
@@ -319,25 +320,25 @@ func TestNonBlockingRecv(t *testing.T) {
// Expected to block immediately.
_, err := r.ReadVec(b)
if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN {
- t.Fatalf("read didn't block, got err %v expected blocking err", err)
+ t.Fatalf("Read didn't block, got err %v expected blocking err", err)
}
// Put some data in the pipe.
w := server.Writer(false)
if n, err := w.WriteVec(b); n != 1 || err != nil {
- t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("Write failed with n=%d err=%v, expected n=1 err=nil", n, err)
}
// Expect it not to block.
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("Read failed with n=%d err=%v, expected n=1 err=nil", n, err)
}
// Expect it to return a block error again.
r = client.Reader(false)
_, err = r.ReadVec(b)
if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN {
- t.Fatalf("read didn't block, got err %v expected blocking err", err)
+ t.Fatalf("Read didn't block, got err %v expected blocking err", err)
}
}
@@ -349,17 +350,17 @@ func TestRecvVectors(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'c'}, {'c'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 2 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err)
}
if b[0][0] != 'a' || b[1][0] != 'b' {
- t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0])
+ t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0])
}
}
@@ -371,17 +372,17 @@ func TestSendVectors(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'c', 'c'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 2 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err)
}
if b[0][0] != 'a' || b[0][1] != 'b' {
- t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1])
+ t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1])
}
}
@@ -394,23 +395,23 @@ func TestSendFDsNotEnabled(t *testing.T) {
w := server.Writer(true)
w.PackFDs(0, 1, 2)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the client, without enabling FDs.
b := [][]byte{{'b'}}
r := client.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Make sure the FDs are not received.
fds, err := r.ExtractFDs()
if len(fds) != 0 || err != nil {
- t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err)
+ t.Fatalf("Got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err)
}
}
@@ -418,7 +419,7 @@ func sendFDs(t *testing.T, s *Socket, fds []int) {
w := s.Writer(true)
w.PackFDs(fds...)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
}
@@ -428,7 +429,7 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
// Count the number of FDs.
preEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
// Read on the client.
@@ -438,31 +439,31 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
r.EnableFDs(enableSize)
}
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Count the new number of FDs.
postEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
if len(preEntries)+expected != len(postEntries) {
- t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries))
+ t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries))
}
// Make sure the FDs are there.
fds, err := r.ExtractFDs()
if len(fds) != expected || err != nil {
- t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected)
+ t.Fatalf("Got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected)
}
// Make sure they are different from the originals.
for i := 0; i < len(fds); i++ {
if fds[i] == origFDs[i] {
- t.Errorf("got original fd for index %d, expected different", i)
+ t.Errorf("Got original fd for index %d, expected different", i)
}
}
@@ -480,10 +481,10 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
// Make sure the count is back to normal.
finalEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
if len(finalEntries) != len(preEntries) {
- t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries))
+ t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries))
}
}
@@ -567,7 +568,7 @@ func TestGetPeerCred(t *testing.T) {
}
if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) {
- t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil)
+ t.Errorf("GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil)
}
}
@@ -594,53 +595,53 @@ func TestGetPeerCredFailure(t *testing.T) {
want := "bad file descriptor"
if _, err := s.GetPeerCred(); err == nil || err.Error() != want {
- t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want)
+ t.Errorf("s.GetPeerCred() = %v, want = %s", err, want)
}
}
func TestAcceptClosed(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
if err := ss.Close(); err != nil {
- t.Fatalf("close failed, got err %v expected nil", err)
+ t.Fatalf("Close failed, got err %v expected nil", err)
}
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
}
func TestCloseAfterAcceptStart(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
+ defer wg.Done()
time.Sleep(50 * time.Millisecond)
if err := ss.Close(); err != nil {
- t.Fatalf("close failed, got err %v expected nil", err)
+ t.Errorf("Close failed, got err %v expected nil", err)
}
- wg.Done()
}()
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
wg.Wait()
@@ -649,28 +650,28 @@ func TestCloseAfterAcceptStart(t *testing.T) {
func TestReleaseAfterAcceptStart(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
+ defer wg.Done()
time.Sleep(50 * time.Millisecond)
fd, err := ss.Release()
if err != nil {
- t.Fatalf("Release failed, got err %v expected nil", err)
+ t.Errorf("Release failed, got err %v expected nil", err)
}
syscall.Close(fd)
- wg.Done()
}()
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
wg.Wait()
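These test changes follow the standard Go rule that t.Fatalf must not be called from a goroutine other than the one running the test: the spawned goroutine reports with t.Errorf and returns, and wg.Done is deferred so the WaitGroup is released even on the error path. A minimal, generic sketch of the pattern (the work closure is a placeholder):

package example_test

import (
	"sync"
	"testing"
)

func TestBackgroundHelper(t *testing.T) {
	work := func() error { return nil } // placeholder for the real work

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done() // released even if we bail out early
		if err := work(); err != nil {
			// t.Errorf + return instead of t.Fatalf: Fatalf calls
			// runtime.Goexit and may only run on the test goroutine.
			t.Errorf("work() failed: %v", err)
			return
		}
	}()
	wg.Wait()
}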
@@ -688,7 +689,7 @@ func TestControlMessage(t *testing.T) {
cm.PackFDs(want...)
got, err := cm.ExtractFDs()
if err != nil || !reflect.DeepEqual(got, want) {
- t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil)
+ t.Errorf("cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil)
}
}
}
@@ -705,11 +706,13 @@ func benchmarkSendRecv(b *testing.B, packet bool) {
for i := 0; i < b.N; i++ {
n, err := server.Read(buf)
if n != 1 || err != nil {
- b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err)
+ b.Errorf("server.Read: got (%d, %v), wanted (1, nil)", n, err)
+ return
}
n, err = server.Write(buf)
if n != 1 || err != nil {
- b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err)
+ b.Errorf("server.Write: got (%d, %v), wanted (1, nil)", n, err)
+ return
}
}
}()