Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/ring0/BUILD (renamed from pkg/sentry/platform/ring0/BUILD) | 10
-rw-r--r-- | pkg/ring0/aarch64.go (renamed from pkg/sentry/platform/ring0/aarch64.go) | 0
-rw-r--r-- | pkg/ring0/defs.go (renamed from pkg/sentry/platform/ring0/defs.go) | 2
-rw-r--r-- | pkg/ring0/defs_amd64.go (renamed from pkg/sentry/platform/ring0/defs_amd64.go) | 1
-rw-r--r-- | pkg/ring0/defs_arm64.go (renamed from pkg/sentry/platform/ring0/defs_arm64.go) | 1
-rw-r--r-- | pkg/ring0/entry_amd64.go (renamed from pkg/sentry/platform/ring0/entry_amd64.go) | 0
-rw-r--r-- | pkg/ring0/entry_amd64.s (renamed from pkg/sentry/platform/ring0/entry_amd64.s) | 0
-rw-r--r-- | pkg/ring0/entry_arm64.go (renamed from pkg/sentry/platform/ring0/entry_arm64.go) | 0
-rw-r--r-- | pkg/ring0/entry_arm64.s (renamed from pkg/sentry/platform/ring0/entry_arm64.s) | 0
-rw-r--r-- | pkg/ring0/gen_offsets/BUILD (renamed from pkg/sentry/platform/ring0/gen_offsets/BUILD) | 8
-rw-r--r-- | pkg/ring0/gen_offsets/main.go (renamed from pkg/sentry/platform/ring0/gen_offsets/main.go) | 0
-rw-r--r-- | pkg/ring0/kernel.go (renamed from pkg/sentry/platform/ring0/kernel.go) | 0
-rw-r--r-- | pkg/ring0/kernel_amd64.go (renamed from pkg/sentry/platform/ring0/kernel_amd64.go) | 0
-rw-r--r-- | pkg/ring0/kernel_arm64.go (renamed from pkg/sentry/platform/ring0/kernel_arm64.go) | 0
-rw-r--r-- | pkg/ring0/kernel_unsafe.go (renamed from pkg/sentry/platform/ring0/kernel_unsafe.go) | 0
-rw-r--r-- | pkg/ring0/lib_amd64.go (renamed from pkg/sentry/platform/ring0/lib_amd64.go) | 0
-rw-r--r-- | pkg/ring0/lib_amd64.s (renamed from pkg/sentry/platform/ring0/lib_amd64.s) | 0
-rw-r--r-- | pkg/ring0/lib_arm64.go (renamed from pkg/sentry/platform/ring0/lib_arm64.go) | 0
-rw-r--r-- | pkg/ring0/lib_arm64.s (renamed from pkg/sentry/platform/ring0/lib_arm64.s) | 0
-rw-r--r-- | pkg/ring0/offsets_amd64.go (renamed from pkg/sentry/platform/ring0/offsets_amd64.go) | 0
-rw-r--r-- | pkg/ring0/offsets_arm64.go (renamed from pkg/sentry/platform/ring0/offsets_arm64.go) | 0
-rw-r--r-- | pkg/ring0/pagetables/BUILD (renamed from pkg/sentry/platform/ring0/pagetables/BUILD) | 8
-rw-r--r-- | pkg/ring0/pagetables/allocator.go (renamed from pkg/sentry/platform/ring0/pagetables/allocator.go) | 0
-rw-r--r-- | pkg/ring0/pagetables/allocator_unsafe.go (renamed from pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go) | 0
-rw-r--r-- | pkg/ring0/pagetables/pagetables.go (renamed from pkg/sentry/platform/ring0/pagetables/pagetables.go) | 60
-rw-r--r-- | pkg/ring0/pagetables/pagetables_aarch64.go (renamed from pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go) | 11
-rw-r--r-- | pkg/ring0/pagetables/pagetables_amd64.go (renamed from pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go) | 2
-rw-r--r-- | pkg/ring0/pagetables/pagetables_amd64_test.go (renamed from pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go) | 0
-rw-r--r-- | pkg/ring0/pagetables/pagetables_arm64.go (renamed from pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go) | 1
-rw-r--r-- | pkg/ring0/pagetables/pagetables_arm64_test.go (renamed from pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go) | 0
-rw-r--r-- | pkg/ring0/pagetables/pagetables_test.go (renamed from pkg/sentry/platform/ring0/pagetables/pagetables_test.go) | 5
-rw-r--r-- | pkg/ring0/pagetables/pagetables_x86.go (renamed from pkg/sentry/platform/ring0/pagetables/pagetables_x86.go) | 5
-rw-r--r-- | pkg/ring0/pagetables/pcids.go (renamed from pkg/sentry/platform/ring0/pagetables/pcids.go) | 0
-rw-r--r-- | pkg/ring0/pagetables/pcids_aarch64.go (renamed from pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go) | 0
-rw-r--r-- | pkg/ring0/pagetables/pcids_aarch64.s (renamed from pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s) | 0
-rw-r--r-- | pkg/ring0/pagetables/pcids_x86.go (renamed from pkg/sentry/platform/ring0/pagetables/pcids_x86.go) | 0
-rw-r--r-- | pkg/ring0/pagetables/walker_amd64.go (renamed from pkg/sentry/platform/ring0/pagetables/walker_amd64.go) | 142
-rw-r--r-- | pkg/ring0/pagetables/walker_arm64.go (renamed from pkg/sentry/platform/ring0/pagetables/walker_arm64.go) | 117
-rw-r--r-- | pkg/ring0/pagetables/walker_generic.go | 110
-rw-r--r-- | pkg/ring0/ring0.go (renamed from pkg/sentry/platform/ring0/ring0.go) | 0
-rw-r--r-- | pkg/ring0/x86.go (renamed from pkg/sentry/platform/ring0/x86.go) | 0
-rw-r--r-- | pkg/sentry/control/proc.go | 13
-rw-r--r-- | pkg/sentry/kernel/pipe/vfs.go | 10
-rw-r--r-- | pkg/sentry/platform/kvm/BUILD | 12
-rw-r--r-- | pkg/sentry/platform/kvm/address_space.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/bluepill.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/bluepill_allocator.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/bluepill_amd64.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/bluepill_arm64.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/context.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/kvm.go | 4
-rw-r--r-- | pkg/sentry/platform/kvm/kvm_amd64.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/kvm_amd64_test.go | 4
-rw-r--r-- | pkg/sentry/platform/kvm/kvm_arm64.go | 2
-rw-r--r-- | pkg/sentry/platform/kvm/kvm_test.go | 4
-rw-r--r-- | pkg/sentry/platform/kvm/machine.go | 4
-rw-r--r-- | pkg/sentry/platform/kvm/machine_amd64.go | 4
-rw-r--r-- | pkg/sentry/platform/kvm/machine_arm64.go | 4
-rw-r--r-- | pkg/sentry/platform/kvm/machine_arm64_unsafe.go | 4
-rw-r--r-- | pkg/sentry/platform/kvm/physical_map.go | 2
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 4
-rw-r--r-- | pkg/sentry/socket/socket.go | 8
-rw-r--r-- | pkg/sentry/syscalls/linux/error.go | 4
-rw-r--r-- | pkg/tcpip/header/icmpv4.go | 14
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 231
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 112
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 2
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 137
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 2
-rw-r--r-- | pkg/tcpip/socketops.go | 63
-rw-r--r-- | pkg/tcpip/stack/nic.go | 7
-rw-r--r-- | pkg/tcpip/stack/registration.go | 77
-rw-r--r-- | pkg/tcpip/stack/stack.go | 13
-rw-r--r-- | pkg/tcpip/stack/stack_options.go | 25
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 7
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 17
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 2
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 54
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 5
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 7
-rw-r--r-- | pkg/tcpip/transport/tcp/cubic.go | 4
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 56
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 10
-rw-r--r-- | pkg/tcpip/transport/tcp/rack.go | 2
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 4
-rw-r--r-- | pkg/tcpip/transport/tcp/reno.go | 8
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 19
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 7
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 34
91 files changed, 989 insertions, 514 deletions
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/ring0/BUILD index 2852b7387..d1b14efdb 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/ring0/BUILD @@ -43,16 +43,16 @@ arch_genrule( name = "entry_impl_amd64", srcs = ["entry_amd64.s"], outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], + cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", + tools = ["//pkg/ring0/gen_offsets"], ) arch_genrule( name = "entry_impl_arm64", srcs = ["entry_arm64.s"], outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], + cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", + tools = ["//pkg/ring0/gen_offsets"], ) go_library( @@ -77,9 +77,9 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/cpuid", + "//pkg/ring0/pagetables", "//pkg/safecopy", "//pkg/sentry/arch", - "//pkg/sentry/platform/ring0/pagetables", "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/ring0/aarch64.go index 3bda594f9..3bda594f9 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/ring0/aarch64.go diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/ring0/defs.go index f9765771e..e2561e4c2 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/ring0/defs.go @@ -15,8 +15,8 @@ package ring0 import ( + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) // Kernel is a global kernel object. 
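The defs.go hunk above shows the import rewrite that repeats throughout this change: ring0 and its pagetables subpackage move out of pkg/sentry/platform/ and into the top-level pkg/ tree. A minimal sketch of what a consumer's imports look like after the move; the package and variable names here are illustrative only, not part of the change:

```go
package consumer // hypothetical; the real consumers are packages such as pkg/sentry/platform/kvm

import (
	// Previously imported as:
	//   "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
	//   "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
	"gvisor.dev/gvisor/pkg/ring0"
	"gvisor.dev/gvisor/pkg/ring0/pagetables"
)

// Reference the moved packages so the sketch stands on its own.
var (
	_ ring0.Kernel          // the global kernel object declared in defs.go
	_ pagetables.PageTables // the page tables type used by the KVM platform
)
```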
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go index 7a2275558..ceddf719d 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/ring0/defs_amd64.go @@ -17,7 +17,6 @@ package ring0 import ( - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/ring0/defs_arm64.go index a014dcbc0..dcb255fc8 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/ring0/defs_arm64.go @@ -17,7 +17,6 @@ package ring0 import ( - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/ring0/entry_amd64.go index d87b1fd00..d87b1fd00 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.go +++ b/pkg/ring0/entry_amd64.go diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/ring0/entry_amd64.s index f59747df3..f59747df3 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.s +++ b/pkg/ring0/entry_amd64.s diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/ring0/entry_arm64.go index 62a93f3d6..62a93f3d6 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.go +++ b/pkg/ring0/entry_arm64.go diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/ring0/entry_arm64.s index b2bb18257..b2bb18257 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/ring0/entry_arm64.s diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD index a9703baf6..15b93d61c 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/ring0/gen_offsets/BUILD @@ -7,14 +7,14 @@ go_template_instance( name = "defs_impl_arm64", out = "defs_impl_arm64.go", package = "main", - template = "//pkg/sentry/platform/ring0:defs_arm64", + template = "//pkg/ring0:defs_arm64", ) go_template_instance( name = "defs_impl_amd64", out = "defs_impl_amd64.go", package = "main", - template = "//pkg/sentry/platform/ring0:defs_amd64", + template = "//pkg/ring0:defs_amd64", ) go_binary( @@ -28,13 +28,13 @@ go_binary( # pass the sentry deps test. 
system_malloc = True, visibility = [ + "//pkg/ring0:__pkg__", "//pkg/sentry/platform/kvm:__pkg__", - "//pkg/sentry/platform/ring0:__pkg__", ], deps = [ "//pkg/cpuid", + "//pkg/ring0/pagetables", "//pkg/sentry/arch", - "//pkg/sentry/platform/ring0/pagetables", "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/ring0/gen_offsets/main.go b/pkg/ring0/gen_offsets/main.go index a4927da2f..a4927da2f 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/main.go +++ b/pkg/ring0/gen_offsets/main.go diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/ring0/kernel.go index 292f9d0cc..292f9d0cc 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/ring0/kernel.go diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index 36a60700e..36a60700e 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go index c05284641..c05284641 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/ring0/kernel_arm64.go diff --git a/pkg/sentry/platform/ring0/kernel_unsafe.go b/pkg/ring0/kernel_unsafe.go index 16955ad91..16955ad91 100644 --- a/pkg/sentry/platform/ring0/kernel_unsafe.go +++ b/pkg/ring0/kernel_unsafe.go diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go index 0ec5c3bc5..0ec5c3bc5 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.go +++ b/pkg/ring0/lib_amd64.go diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s index 2fe83568a..2fe83568a 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.s +++ b/pkg/ring0/lib_amd64.s diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/ring0/lib_arm64.go index a490bf3af..a490bf3af 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/ring0/lib_arm64.go diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/ring0/lib_arm64.s index e39b32841..e39b32841 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/ring0/lib_arm64.s diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/ring0/offsets_amd64.go index ca4075b09..ca4075b09 100644 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ b/pkg/ring0/offsets_amd64.go diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/ring0/offsets_arm64.go index 164db6d5a..164db6d5a 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/ring0/offsets_arm64.go diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/ring0/pagetables/BUILD index 9e3539e4c..65a978cbb 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/ring0/pagetables/BUILD @@ -9,7 +9,10 @@ package(licenses = ["notice"]) # architecture builds. 
go_template( name = "generic_walker_%s" % arch, - srcs = ["walker_%s.go" % arch], + srcs = [ + "walker_generic.go", + "walker_%s.go" % arch, + ], opt_types = [ "Visitor", ], @@ -50,6 +53,7 @@ go_library( "pcids_x86.go", "walker_amd64.go", "walker_arm64.go", + "walker_generic.go", ":walker_empty_amd64", ":walker_empty_arm64", ":walker_lookup_amd64", @@ -60,8 +64,8 @@ go_library( ":walker_unmap_arm64", ], visibility = [ + "//pkg/ring0:__subpackages__", "//pkg/sentry/platform/kvm:__subpackages__", - "//pkg/sentry/platform/ring0:__subpackages__", ], deps = [ "//pkg/sync", diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/ring0/pagetables/allocator.go index 8d75b7599..8d75b7599 100644 --- a/pkg/sentry/platform/ring0/pagetables/allocator.go +++ b/pkg/ring0/pagetables/allocator.go diff --git a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go b/pkg/ring0/pagetables/allocator_unsafe.go index d08bfdeb3..d08bfdeb3 100644 --- a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go +++ b/pkg/ring0/pagetables/allocator_unsafe.go diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 7605d0cb2..8c0a6aa82 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -60,6 +60,7 @@ type PageTables struct { // Init initializes a set of PageTables. // +// +checkescape:hard,stack //go:nosplit func (p *PageTables) Init(allocator Allocator) { p.Allocator = allocator @@ -92,7 +93,6 @@ func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uin } p.InitArch(a) - return p } @@ -112,7 +112,7 @@ type mapVisitor struct { // visit is used for map. // //go:nosplit -func (v *mapVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *mapVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { p := v.physical + (start - uintptr(v.target)) if pte.Valid() && (pte.Address() != p || pte.Opts() != v.opts) { v.prev = true @@ -122,9 +122,10 @@ func (v *mapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // install a valid entry here, however we must zap any existing // entry to ensure this happens. pte.Clear() - return + return true } pte.Set(p, v.opts) + return true } //go:nosplit @@ -140,7 +141,6 @@ func (*mapVisitor) requiresSplit() bool { return true } // Precondition: addr & length must be page-aligned, their sum must not overflow. // // +checkescape:hard,stack -// //go:nosplit func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { if p.readOnlyShared { @@ -158,9 +158,6 @@ func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physic length = p.upperStart - uintptr(addr) } } - if !opts.AccessType.Any() { - return p.Unmap(addr, length) - } w := mapWalker{ pageTables: p, visitor: mapVisitor{ @@ -187,9 +184,10 @@ func (*unmapVisitor) requiresSplit() bool { return true } // visit unmaps the given entry. // //go:nosplit -func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { pte.Clear() v.count++ + return true } // Unmap unmaps the given range. @@ -199,7 +197,6 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // Precondition: addr & length must be page-aligned, their sum must not overflow. 
// // +checkescape:hard,stack -// //go:nosplit func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { if p.readOnlyShared { @@ -241,8 +238,9 @@ func (*emptyVisitor) requiresSplit() bool { return false } // visit unmaps the given entry. // //go:nosplit -func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { v.count++ + return true } // IsEmpty checks if the given range is empty. @@ -250,7 +248,6 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) { // Precondition: addr & length must be page-aligned. // // +checkescape:hard,stack -// //go:nosplit func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool { w := emptyWalker{ @@ -262,20 +259,28 @@ func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool { // lookupVisitor is used for lookup. type lookupVisitor struct { - target uintptr // Input. - physical uintptr // Output. - opts MapOpts // Output. + target uintptr // Input & Output. + findFirst bool // Input. + physical uintptr // Output. + size uintptr // Output. + opts MapOpts // Output. } // visit matches the given address. // //go:nosplit -func (v *lookupVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *lookupVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { if !pte.Valid() { - return + // If looking for the first, then we just keep iterating until + // we find a valid entry. + return v.findFirst } - v.physical = pte.Address() + (start - uintptr(v.target)) + // Is this within the current range? + v.target = start + v.physical = pte.Address() + v.size = (align + 1) v.opts = pte.Opts() + return false } //go:nosplit @@ -286,20 +291,29 @@ func (*lookupVisitor) requiresSplit() bool { return false } // Lookup returns the physical address for the given virtual address. // -// +checkescape:hard,stack +// If findFirst is true, then the next valid address after addr is returned. +// If findFirst is false, then only a mapping for addr will be returned. +// +// Note that if size is zero, then no matching entry was found. // +// +checkescape:hard,stack //go:nosplit -func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) { +func (p *PageTables) Lookup(addr usermem.Addr, findFirst bool) (virtual usermem.Addr, physical, size uintptr, opts MapOpts) { mask := uintptr(usermem.PageSize - 1) - offset := uintptr(addr) & mask + addr &^= usermem.Addr(mask) w := lookupWalker{ pageTables: p, visitor: lookupVisitor{ - target: uintptr(addr &^ usermem.Addr(mask)), + target: uintptr(addr), + findFirst: findFirst, }, } - w.iterateRange(uintptr(addr), uintptr(addr)+1) - return w.visitor.physical + offset, w.visitor.opts + end := ^usermem.Addr(0) &^ usermem.Addr(mask) + if !findFirst { + end = addr + 1 + } + w.iterateRange(uintptr(addr), uintptr(end)) + return usermem.Addr(w.visitor.target), w.visitor.physical, w.visitor.size, w.visitor.opts } // MarkReadOnlyShared marks the pagetables read-only and can be shared. 
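The pagetables.go hunks above change Lookup from Lookup(addr) (physical, opts) to a four-result form that can also scan forward for the first valid mapping, and make each visitor's visit return a bool so the walk can stop early. A hedged sketch of how a caller might use the new Lookup; only the signature and the "size == 0 means no match" rule come from the diff, while the helper functions and package name are invented for illustration:

```go
package example // hypothetical consumer of pkg/ring0/pagetables

import (
	"gvisor.dev/gvisor/pkg/ring0/pagetables"
	"gvisor.dev/gvisor/pkg/usermem"
)

// findFirstMapping returns the first mapped region at or after addr.
// A returned size of zero means nothing is mapped at or above addr.
func findFirstMapping(pt *pagetables.PageTables, addr usermem.Addr) (virtual usermem.Addr, physical, size uintptr, opts pagetables.MapOpts) {
	return pt.Lookup(addr, true /* findFirst */)
}

// exactLookup keeps the old single-address behaviour: it reports a mapping
// only if addr itself is mapped.
func exactLookup(pt *pagetables.PageTables, addr usermem.Addr) (physical uintptr, opts pagetables.MapOpts, ok bool) {
	_, physical, size, opts := pt.Lookup(addr, false /* findFirst */)
	return physical, opts, size != 0
}
```

Returning the mapping's own virtual start and size lets a caller walk a sparse address space without probing it page by page.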
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/ring0/pagetables/pagetables_aarch64.go index 520161755..163a3aea3 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/ring0/pagetables/pagetables_aarch64.go @@ -156,12 +156,7 @@ func (p *PTE) IsSect() bool { // //go:nosplit func (p *PTE) Set(addr uintptr, opts MapOpts) { - if !opts.AccessType.Any() { - p.Clear() - return - } - v := (addr &^ optionMask) | protDefault | nG | readOnly - + v := (addr &^ optionMask) | nG | readOnly | protDefault if p.IsSect() { // Note that this is inherited from the previous instance. Set // does not change the value of Sect. See above. @@ -169,6 +164,10 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) { } else { v |= typePage } + if !opts.AccessType.Any() { + // Leave as non-valid if no access is available. + v &^= pteValid + } if opts.Global { v = v &^ nG diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/ring0/pagetables/pagetables_amd64.go index 4bdde8448..a217f404c 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/ring0/pagetables/pagetables_amd64.go @@ -43,6 +43,7 @@ const ( // InitArch does some additional initialization related to the architecture. // +// +checkescape:hard,stack //go:nosplit func (p *PageTables) InitArch(allocator Allocator) { if p.upperSharedPageTables != nil { @@ -50,6 +51,7 @@ func (p *PageTables) InitArch(allocator Allocator) { } } +//go:nosplit func pgdIndex(upperStart uintptr) uintptr { if upperStart&(pgdSize-1) != 0 { panic("upperStart should be pgd size aligned") diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go b/pkg/ring0/pagetables/pagetables_amd64_test.go index 54e8e554f..54e8e554f 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go +++ b/pkg/ring0/pagetables/pagetables_amd64_test.go diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/ring0/pagetables/pagetables_arm64.go index ad0e30c88..fef7a0fd1 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/ring0/pagetables/pagetables_arm64.go @@ -44,6 +44,7 @@ const ( // InitArch does some additional initialization related to the architecture. // +// +checkescape:hard,stack //go:nosplit func (p *PageTables) InitArch(allocator Allocator) { if p.upperSharedPageTables != nil { diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/ring0/pagetables/pagetables_arm64_test.go index 2f73d424f..2f73d424f 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go +++ b/pkg/ring0/pagetables/pagetables_arm64_test.go diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go b/pkg/ring0/pagetables/pagetables_test.go index 5c88d087d..772f4fc5e 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go +++ b/pkg/ring0/pagetables/pagetables_test.go @@ -34,7 +34,7 @@ type checkVisitor struct { failed string // Output. } -func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { v.found = append(v.found, mapping{ start: start, length: align + 1, @@ -43,7 +43,7 @@ func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) { }) if v.failed != "" { // Don't keep looking for errors. 
- return + return false } if v.current >= len(v.expected) { @@ -58,6 +58,7 @@ func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) { v.failed = "opts didn't match" } v.current++ + return true } func (*checkVisitor) requiresAlloc() bool { return false } diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/ring0/pagetables/pagetables_x86.go index 157438d9b..32edd2f0a 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go +++ b/pkg/ring0/pagetables/pagetables_x86.go @@ -137,7 +137,10 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) { p.Clear() return } - v := (addr &^ optionMask) | present | accessed + v := (addr &^ optionMask) + if opts.AccessType.Any() { + v |= present | accessed + } if opts.User { v |= user } diff --git a/pkg/sentry/platform/ring0/pagetables/pcids.go b/pkg/ring0/pagetables/pcids.go index 964496aac..964496aac 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids.go +++ b/pkg/ring0/pagetables/pcids.go diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go b/pkg/ring0/pagetables/pcids_aarch64.go index fbfd41d83..fbfd41d83 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go +++ b/pkg/ring0/pagetables/pcids_aarch64.go diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s b/pkg/ring0/pagetables/pcids_aarch64.s index e9d62d768..e9d62d768 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s +++ b/pkg/ring0/pagetables/pcids_aarch64.s diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/ring0/pagetables/pcids_x86.go index 91fc5e8dd..91fc5e8dd 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go +++ b/pkg/ring0/pagetables/pcids_x86.go diff --git a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go b/pkg/ring0/pagetables/walker_amd64.go index 8f9dacd93..eb4fbcc31 100644 --- a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go +++ b/pkg/ring0/pagetables/walker_amd64.go @@ -16,104 +16,10 @@ package pagetables -// Visitor is a generic type. -type Visitor interface { - // visit is called on each PTE. - visit(start uintptr, pte *PTE, align uintptr) - - // requiresAlloc indicates that new entries should be allocated within - // the walked range. - requiresAlloc() bool - - // requiresSplit indicates that entries in the given range should be - // split if they are huge or jumbo pages. - requiresSplit() bool -} - -// Walker walks page tables. -type Walker struct { - // pageTables are the tables to walk. - pageTables *PageTables - - // Visitor is the set of arguments. - visitor Visitor -} - -// iterateRange iterates over all appropriate levels of page tables for the given range. -// -// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The -// exception is super pages. If a valid super page (huge or jumbo) cannot be -// installed, then the walk will continue to individual entries. -// -// This algorithm will attempt to maximize the use of super pages whenever -// possible. Whether a super page is provided will be clear through the range -// provided in the callback. -// -// Note that if requiresAlloc is true, then no gaps will be present. However, -// if alloc is not set, then the iteration will likely be full of gaps. -// -// Note that this function should generally be avoided in favor of Map, Unmap, -// etc. when not necessary. -// -// Precondition: start must be page-aligned. -// -// Precondition: start must be less than end. -// -// Precondition: If requiresAlloc is true, then start and end should not span -// non-canonical ranges. 
If they do, a panic will result. -// -//go:nosplit -func (w *Walker) iterateRange(start, end uintptr) { - if start%pteSize != 0 { - panic("unaligned start") - } - if end < start { - panic("start > end") - } - if start < lowerTop { - if end <= lowerTop { - w.iterateRangeCanonical(start, end) - } else if end > lowerTop && end <= upperBottom { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(start, lowerTop) - } else { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(start, lowerTop) - w.iterateRangeCanonical(upperBottom, end) - } - } else if start < upperBottom { - if end <= upperBottom { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - } else { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(upperBottom, end) - } - } else { - w.iterateRangeCanonical(start, end) - } -} - -// next returns the next address quantized by the given size. -// -//go:nosplit -func next(start uintptr, size uintptr) uintptr { - start &= ^(size - 1) - start += size - return start -} - // iterateRangeCanonical walks a canonical range. // //go:nosplit -func (w *Walker) iterateRangeCanonical(start, end uintptr) { +func (w *Walker) iterateRangeCanonical(start, end uintptr) bool { for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { var ( pgdEntry = &w.pageTables.root[pgdIndex] @@ -127,10 +33,10 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // Allocate a new pgd. - pudEntries = w.pageTables.Allocator.NewPTEs() + pudEntries = w.pageTables.Allocator.NewPTEs() // escapes: depends on allocator. pgdEntry.setPageTable(w.pageTables, pudEntries) } else { - pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) // escapes: see above. } // Map the next level. @@ -155,7 +61,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // new page for the pmd. if start&(pudSize-1) == 0 && end-start >= pudSize { pudEntry.SetSuper() - w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } if pudEntry.Valid() { start = next(start, pudSize) continue @@ -163,14 +71,14 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // Allocate a new pud. - pmdEntries = w.pageTables.Allocator.NewPTEs() + pmdEntries = w.pageTables.Allocator.NewPTEs() // escapes: see above. pudEntry.setPageTable(w.pageTables, pmdEntries) } else if pudEntry.IsSuper() { // Does this page need to be split? if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) { // Install the relevant entries. - pmdEntries = w.pageTables.Allocator.NewPTEs() + pmdEntries = w.pageTables.Allocator.NewPTEs() // escapes: see above. for index := uint16(0); index < entriesPerPage; index++ { pmdEntries[index].SetSuper() pmdEntries[index].Set( @@ -180,7 +88,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { pudEntry.setPageTable(w.pageTables, pmdEntries) } else { // A super page to be checked directly. - w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } // Might have been cleared. 
if !pudEntry.Valid() { @@ -192,7 +102,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { continue } } else { - pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) // escapes: see above. } // Map the next level, since this is valid. @@ -216,7 +126,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // As above, we can skip allocating a new page. if start&(pmdSize-1) == 0 && end-start >= pmdSize { pmdEntry.SetSuper() - w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } if pmdEntry.Valid() { start = next(start, pmdSize) continue @@ -224,7 +136,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // Allocate a new pmd. - pteEntries = w.pageTables.Allocator.NewPTEs() + pteEntries = w.pageTables.Allocator.NewPTEs() // escapes: see above. pmdEntry.setPageTable(w.pageTables, pteEntries) } else if pmdEntry.IsSuper() { @@ -240,7 +152,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { pmdEntry.setPageTable(w.pageTables, pteEntries) } else { // A huge page to be checked directly. - w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } // Might have been cleared. if !pmdEntry.Valid() { @@ -252,7 +166,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { continue } } else { - pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) // escapes: see above. } // Map the next level, since this is valid. @@ -269,11 +183,10 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // At this point, we are guaranteed that start%pteSize == 0. - w.visitor.visit(uintptr(start), pteEntry, pteSize-1) - if !pteEntry.Valid() { - if w.visitor.requiresAlloc() { - panic("PTE not set after iteration with requiresAlloc!") - } + if !w.visitor.visit(uintptr(start&^(pteSize-1)), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { clearPTEEntries++ } @@ -285,7 +198,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // Check if we no longer need this page. if clearPTEEntries == entriesPerPage { pmdEntry.Clear() - w.pageTables.Allocator.FreePTEs(pteEntries) + w.pageTables.Allocator.FreePTEs(pteEntries) // escapes: see above. clearPMDEntries++ } } @@ -293,7 +206,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // Check if we no longer need this page. if clearPMDEntries == entriesPerPage { pudEntry.Clear() - w.pageTables.Allocator.FreePTEs(pmdEntries) + w.pageTables.Allocator.FreePTEs(pmdEntries) // escapes: see above. clearPUDEntries++ } } @@ -301,7 +214,8 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // Check if we no longer need this page. if clearPUDEntries == entriesPerPage { pgdEntry.Clear() - w.pageTables.Allocator.FreePTEs(pudEntries) + w.pageTables.Allocator.FreePTEs(pudEntries) // escapes: see above. } } + return true } diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/ring0/pagetables/walker_arm64.go index c261d393a..5ed881c7a 100644 --- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go +++ b/pkg/ring0/pagetables/walker_arm64.go @@ -16,104 +16,10 @@ package pagetables -// Visitor is a generic type. -type Visitor interface { - // visit is called on each PTE. 
- visit(start uintptr, pte *PTE, align uintptr) - - // requiresAlloc indicates that new entries should be allocated within - // the walked range. - requiresAlloc() bool - - // requiresSplit indicates that entries in the given range should be - // split if they are huge or jumbo pages. - requiresSplit() bool -} - -// Walker walks page tables. -type Walker struct { - // pageTables are the tables to walk. - pageTables *PageTables - - // Visitor is the set of arguments. - visitor Visitor -} - -// iterateRange iterates over all appropriate levels of page tables for the given range. -// -// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The -// exception is sect pages. If a valid sect page (huge or jumbo) cannot be -// installed, then the walk will continue to individual entries. -// -// This algorithm will attempt to maximize the use of sect pages whenever -// possible. Whether a sect page is provided will be clear through the range -// provided in the callback. -// -// Note that if requiresAlloc is true, then no gaps will be present. However, -// if alloc is not set, then the iteration will likely be full of gaps. -// -// Note that this function should generally be avoided in favor of Map, Unmap, -// etc. when not necessary. -// -// Precondition: start must be page-aligned. -// -// Precondition: start must be less than end. -// -// Precondition: If requiresAlloc is true, then start and end should not span -// non-canonical ranges. If they do, a panic will result. -// -//go:nosplit -func (w *Walker) iterateRange(start, end uintptr) { - if start%pteSize != 0 { - panic("unaligned start") - } - if end < start { - panic("start > end") - } - if start < lowerTop { - if end <= lowerTop { - w.iterateRangeCanonical(start, end) - } else if end > lowerTop && end <= upperBottom { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(start, lowerTop) - } else { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(start, lowerTop) - w.iterateRangeCanonical(upperBottom, end) - } - } else if start < upperBottom { - if end <= upperBottom { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - } else { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(upperBottom, end) - } - } else { - w.iterateRangeCanonical(start, end) - } -} - -// next returns the next address quantized by the given size. -// -//go:nosplit -func next(start uintptr, size uintptr) uintptr { - start &= ^(size - 1) - start += size - return start -} - // iterateRangeCanonical walks a canonical range. // //go:nosplit -func (w *Walker) iterateRangeCanonical(start, end uintptr) { +func (w *Walker) iterateRangeCanonical(start, end uintptr) bool { pgdEntryIndex := w.pageTables.root if start >= upperBottom { pgdEntryIndex = w.pageTables.archPageTables.root @@ -160,7 +66,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // new page for the pmd. if start&(pudSize-1) == 0 && end-start >= pudSize { pudEntry.SetSect() - w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } if pudEntry.Valid() { start = next(start, pudSize) continue @@ -185,7 +93,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { pudEntry.setPageTable(w.pageTables, pmdEntries) } else { // A sect page to be checked directly. 
- w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } // Might have been cleared. if !pudEntry.Valid() { @@ -222,7 +132,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // As above, we can skip allocating a new page. if start&(pmdSize-1) == 0 && end-start >= pmdSize { pmdEntry.SetSect() - w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } if pmdEntry.Valid() { start = next(start, pmdSize) continue @@ -246,7 +158,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { pmdEntry.setPageTable(w.pageTables, pteEntries) } else { // A huge page to be checked directly. - w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } // Might have been cleared. if !pmdEntry.Valid() { @@ -276,7 +190,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // At this point, we are guaranteed that start%pteSize == 0. - w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !w.visitor.visit(uintptr(start), pteEntry, pteSize-1) { + return false + } if !pteEntry.Valid() { if w.visitor.requiresAlloc() { panic("PTE not set after iteration with requiresAlloc!") @@ -311,4 +227,5 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { w.pageTables.Allocator.FreePTEs(pudEntries) } } + return true } diff --git a/pkg/ring0/pagetables/walker_generic.go b/pkg/ring0/pagetables/walker_generic.go new file mode 100644 index 000000000..34fba7b84 --- /dev/null +++ b/pkg/ring0/pagetables/walker_generic.go @@ -0,0 +1,110 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagetables + +// Visitor is a generic type. +type Visitor interface { + // visit is called on each PTE. The returned boolean indicates whether + // the walk should continue. + visit(start uintptr, pte *PTE, align uintptr) bool + + // requiresAlloc indicates that new entries should be allocated within + // the walked range. + requiresAlloc() bool + + // requiresSplit indicates that entries in the given range should be + // split if they are huge or jumbo pages. + requiresSplit() bool +} + +// Walker walks page tables. +type Walker struct { + // pageTables are the tables to walk. + pageTables *PageTables + + // Visitor is the set of arguments. + visitor Visitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. 
+// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. +// +//go:nosplit +func (w *Walker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func next(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/sentry/platform/ring0/ring0.go b/pkg/ring0/ring0.go index cdeb1b43a..cdeb1b43a 100644 --- a/pkg/sentry/platform/ring0/ring0.go +++ b/pkg/ring0/ring0.go diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/ring0/x86.go index 34fbc1c35..34fbc1c35 100644 --- a/pkg/sentry/platform/ring0/x86.go +++ b/pkg/ring0/x86.go diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 1d88db12f..de7a0f3ab 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -404,3 +404,16 @@ func ttyName(tty *kernel.TTY) string { } return fmt.Sprintf("pts/%d", tty.Index) } + +// ContainerUsage retrieves per-container CPU usage. +func ContainerUsage(kr *kernel.Kernel) map[string]uint64 { + cusage := make(map[string]uint64) + for _, tg := range kr.TaskSet().Root.ThreadGroups() { + // We want each tg's usage including reaped children. + cid := tg.Leader().ContainerID() + stats := tg.CPUStats() + stats.Accumulate(tg.JoinedChildCPUStats()) + cusage[cid] += uint64(stats.UserTime.Nanoseconds()) + uint64(stats.SysTime.Nanoseconds()) + } + return cusage +} diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 3b6336e94..09c0ccaf2 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -368,17 +368,15 @@ func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst }) } -// CopyOutFrom implements usermem.IO.CopyOutFrom. +// CopyOutFrom implements usermem.IO.CopyOutFrom. Note that it is the caller's +// responsibility to call fd.pipe.Notify(waiter.EventIn) after the write is +// completed. // // Preconditions: fd.pipe.mu must be locked. 
func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { - n, err := fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) { + return fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) { return src.ReadToBlocks(dsts) }) - if n > 0 { - fd.pipe.Notify(waiter.EventIn) - } - return n, err } // SwapUint32 implements usermem.IO.SwapUint32. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 8ce411102..b3290917e 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -45,14 +45,14 @@ go_library( "//pkg/cpuid", "//pkg/log", "//pkg/procid", + "//pkg/ring0", + "//pkg/ring0/pagetables", "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", - "//pkg/sentry/platform/ring0", - "//pkg/sentry/platform/ring0/pagetables", "//pkg/sentry/time", "//pkg/sync", "//pkg/usermem", @@ -75,11 +75,11 @@ go_test( "requires-kvm", ], deps = [ + "//pkg/ring0", + "//pkg/ring0/pagetables", "//pkg/sentry/arch", "//pkg/sentry/platform", "//pkg/sentry/platform/kvm/testutil", - "//pkg/sentry/platform/ring0", - "//pkg/sentry/platform/ring0/pagetables", "//pkg/sentry/time", "//pkg/usermem", ], @@ -89,6 +89,6 @@ genrule( name = "bluepill_impl_amd64", srcs = ["bluepill_amd64.s"], outs = ["bluepill_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], + cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/ring0/gen_offsets"], ) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index af5c5e191..25c21e843 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -18,9 +18,9 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index 4b23f7803..2c970162e 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -19,9 +19,9 @@ import ( "reflect" "syscall" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) // bluepill enters guest mode. 
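A few hunks back, the new walker_generic.go centralizes the Visitor and Walker types that were previously duplicated in walker_amd64.go and walker_arm64.go, and visit now returns a bool so a visitor can abort the walk early (this is what lets lookupVisitor stop at the first valid entry). A sketch of a hypothetical visitor written against that contract; the type and its purpose are invented for illustration:

```go
package pagetables // visitors live in this package, since visit is unexported

// countVisitor is a hypothetical read-only visitor that counts valid entries
// and stops the walk once a limit is reached.
type countVisitor struct {
	limit int // Input.
	count int // Output.
}

// visit is called on each PTE; returning false aborts the walk early.
//
//go:nosplit
func (v *countVisitor) visit(start uintptr, pte *PTE, align uintptr) bool {
	if pte.Valid() {
		v.count++
	}
	return v.count < v.limit
}

// requiresAlloc is false: counting never creates entries.
//
//go:nosplit
func (*countVisitor) requiresAlloc() bool { return false }

// requiresSplit is false: a huge or sect page counts as a single entry.
//
//go:nosplit
func (*countVisitor) requiresSplit() bool { return false }
```

In the real package such a visitor would typically be paired with a generated walker instance (like the existing :walker_empty_* and :walker_lookup_* targets in the pagetables BUILD hunk above) rather than used through the generic Walker type directly.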
diff --git a/pkg/sentry/platform/kvm/bluepill_allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go index 9485e1301..1825edc3a 100644 --- a/pkg/sentry/platform/kvm/bluepill_allocator.go +++ b/pkg/sentry/platform/kvm/bluepill_allocator.go @@ -17,7 +17,7 @@ package kvm import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/ring0/pagetables" ) type allocator struct { diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index ddc1554d5..83a4766fb 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -19,8 +19,8 @@ package kvm import ( "syscall" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) var ( diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index f8ccb7430..0063e947b 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -20,8 +20,8 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) // dieArchSetup initializes the state for dieTrampoline. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 1f09813ba..35298135a 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -19,8 +19,8 @@ package kvm import ( "syscall" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) var ( diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 4d912769a..dbbf2a897 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -20,8 +20,8 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) // fpsimdPtr returns a fpsimd64 for the given address. 
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go index 17268d127..aeae01dbd 100644 --- a/pkg/sentry/platform/kvm/context.go +++ b/pkg/sentry/platform/kvm/context.go @@ -18,10 +18,10 @@ import ( "sync/atomic" pkgcontext "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index 5979aef97..7bdf57436 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -20,9 +20,9 @@ import ( "os" "syscall" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go index 093497bc4..b9ed4a706 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64.go +++ b/pkg/sentry/platform/kvm/kvm_amd64.go @@ -18,7 +18,7 @@ package kvm import ( "gvisor.dev/gvisor/pkg/cpuid" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/ring0" ) // userRegs represents KVM user registers. diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go index c0b4fd374..76fc594a0 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -19,11 +19,11 @@ package kvm import ( "testing" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) func TestSegments(t *testing.T) { diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 9db1db4e9..b73340f0e 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -17,8 +17,8 @@ package kvm import ( + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) type kvmOneReg struct { diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index a650877d6..11ca1f0ea 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -22,11 +22,11 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index e2fffc99b..1ece1b8d8 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + 
"gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 8e03c310d..59c752d73 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -24,10 +24,10 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index aa2d21748..7d7857067 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -17,10 +17,10 @@ package kvm import ( + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index a466acf4d..dca0cdb60 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -23,10 +23,10 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index f7fa2f98d..8bdec93ae 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -20,7 +20,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 94f03af48..69693f263 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2666,9 +2666,9 @@ func (s *socketOpsCommon) dequeueErr() *tcpip.SockError { } // Update socket error to reflect ICMP errors in queue. 
- if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() { + if nextErr := so.PeekErr(); nextErr != nil && nextErr.Cause.Origin().IsICMPErr() { so.SetLastError(nextErr.Err) - } else if err.ErrOrigin.IsICMPErr() { + } else if err.Cause.Origin().IsICMPErr() { so.SetLastError(nil) } return err diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 97729dacc..cc535d794 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -81,10 +81,10 @@ func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { ee := linux.SockExtendedErr{ Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), - Origin: errOriginToLinux(sockErr.ErrOrigin), - Type: sockErr.ErrType, - Code: sockErr.ErrCode, - Info: sockErr.ErrInfo, + Origin: errOriginToLinux(sockErr.Cause.Origin()), + Type: sockErr.Cause.Type(), + Code: sockErr.Cause.Code(), + Info: sockErr.Cause.Info(), } switch sockErr.NetProto { diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index dab6207c0..d1778d029 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -134,8 +134,8 @@ func handleIOErrorImpl(t *kernel.Task, partialResult bool, err, intr error, op s // Similar to EPIPE. Return what we wrote this time, and let // ENOSPC be returned on the next call. return true, nil - case syserror.ECONNRESET: - // For TCP sendfile connections, we may have a reset. But we + case syserror.ECONNRESET, syserror.ETIMEDOUT: + // For TCP sendfile connections, we may have a reset or timeout. But we // should just return n as the result. return true, nil case syserror.ErrWouldBlock: diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 5f9b8e9e2..f840a4322 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -16,7 +16,6 @@ package header import ( "encoding/binary" - "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -208,16 +207,3 @@ func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { return ^xsum } - -// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when -// a packet having a `net` header causing an ICMP error. -func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin { - switch net { - case IPv4ProtocolNumber: - return tcpip.SockExtErrorOriginICMP - case IPv6ProtocolNumber: - return tcpip.SockExtErrorOriginICMP6 - default: - panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net)) - } -} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 8d155344b..6a1f11a36 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -59,6 +59,14 @@ var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ PrefixLen: 120, } +type transportError struct { + origin tcpip.SockErrOrigin + typ uint8 + code uint8 + info uint32 + kind stack.TransportErrorKind +} + // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. // The former is used to pretend that it's a link endpoint so that we can // inspect packets written by the network endpoints. 
The latter is used to @@ -74,8 +82,7 @@ type testObject struct { srcAddr tcpip.Address dstAddr tcpip.Address v4 bool - typ stack.ControlType - extra uint32 + transErr transportError dataCalls int controlCalls int @@ -119,16 +126,23 @@ func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumb return stack.TransportPacketHandled } -// DeliverTransportControlPacket is called by network endpoints after parsing +// DeliverTransportError is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) - if typ != t.typ { - t.t.Errorf("typ = %v, want %v", typ, t.typ) - } - if extra != t.extra { - t.t.Errorf("extra = %v, want %v", extra, t.extra) + if diff := cmp.Diff( + t.transErr, + transportError{ + origin: transErr.Origin(), + typ: transErr.Type(), + code: transErr.Code(), + info: transErr.Info(), + kind: transErr.Kind(), + }, + cmp.AllowUnexported(transportError{}), + ); diff != "" { + t.t.Errorf("transport error mismatch (-want +got):\n%s", diff) } t.controlCalls++ } @@ -702,24 +716,81 @@ func TestReceive(t *testing.T) { } func TestIPv4ReceiveControl(t *testing.T) { - const mtu = 0xbeef - header.IPv4MinimumSize + const ( + mtu = 0xbeef - header.IPv4MinimumSize + dataLen = 8 + ) + cases := []struct { name string expectedCount int fragmentOffset uint16 code header.ICMPv4Code - expectedTyp stack.ControlType - expectedExtra uint32 + transErr transportError trunc int }{ - {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, - {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, - {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, + { + name: "FragmentationNeeded", + expectedCount: 1, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4FragmentationNeeded), + info: mtu, + kind: stack.PacketTooBigTransportError, + }, + trunc: 0, + }, + { + name: "Truncated (missing IPv4 header)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize, + }, + { + 
name: "Truncated (partial offending packet's IP header)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1, + }, + { + name: "Truncated (partial offending packet's data)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1, + }, + { + name: "Port unreachable", + expectedCount: 1, + fragmentOffset: 0, + code: header.ICMPv4PortUnreachable, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4PortUnreachable), + kind: stack.DestinationPortUnreachableTransportError, + }, + trunc: 0, + }, + { + name: "Non-zero fragment offset", + expectedCount: 0, + fragmentOffset: 100, + code: header.ICMPv4PortUnreachable, + trunc: 0, + }, + { + name: "Zero-length packet", + expectedCount: 0, + fragmentOffset: 100, + code: header.ICMPv4PortUnreachable, + trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen, + }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -738,7 +809,7 @@ func TestIPv4ReceiveControl(t *testing.T) { } const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize - view := buffer.NewView(dataOffset + 8) + view := buffer.NewView(dataOffset + dataLen) // Create the outer IPv4 header. ip := header.IPv4(view) @@ -785,8 +856,7 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.srcAddr = remoteIPv4Addr nic.testObject.dstAddr = localIPv4Addr nic.testObject.contents = view[dataOffset:] - nic.testObject.typ = c.expectedTyp - nic.testObject.extra = c.expectedExtra + nic.testObject.transErr = c.transErr addressableEndpoint, ok := ep.(stack.AddressableEndpoint) if !ok { @@ -953,30 +1023,112 @@ func TestIPv6Send(t *testing.T) { } func TestIPv6ReceiveControl(t *testing.T) { + const ( + mtu = 0xffff + outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" + dataLen = 8 + ) + newUint16 := func(v uint16) *uint16 { return &v } - const mtu = 0xffff - const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" + portUnreachableTransErr := transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6DstUnreachable), + code: uint8(header.ICMPv6PortUnreachable), + kind: stack.DestinationPortUnreachableTransportError, + } + cases := []struct { name string expectedCount int fragmentOffset *uint16 typ header.ICMPv6Type code header.ICMPv6Code - expectedTyp stack.ControlType - expectedExtra uint32 + transErr transportError trunc int }{ - {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, - {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, - {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Truncated DstUnreachable (missing 
'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, - {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, + { + name: "PacketTooBig", + expectedCount: 1, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6PacketTooBig), + code: uint8(header.ICMPv6UnusedCode), + info: mtu, + kind: stack.PacketTooBigTransportError, + }, + trunc: 0, + }, + { + name: "Truncated (missing offending packet's IPv6 header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize, + }, + { + name: "Truncated PacketTooBig (partial offending packet's IPv6 header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1, + }, + { + name: "Truncated (partial offending packet's data)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1, + }, + { + name: "Port unreachable", + expectedCount: 1, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "Truncated DstPortUnreachable (partial offending packet's IP header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1, + }, + { + name: "DstPortUnreachable for Fragmented, zero offset", + expectedCount: 1, + fragmentOffset: newUint16(0), + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "DstPortUnreachable for Non-zero fragment offset", + expectedCount: 0, + fragmentOffset: newUint16(100), + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "Zero-length packet", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen, + }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -998,7 +1150,7 @@ func TestIPv6ReceiveControl(t *testing.T) { if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize } - view := buffer.NewView(dataOffset + 8) + view := buffer.NewView(dataOffset + dataLen) // Create the outer IPv6 header. 
ip := header.IPv6(view) @@ -1049,8 +1201,7 @@ func TestIPv6ReceiveControl(t *testing.T) { nic.testObject.srcAddr = remoteIPv6Addr nic.testObject.dstAddr = localIPv6Addr nic.testObject.contents = view[dataOffset:] - nic.testObject.typ = c.expectedTyp - nic.testObject.extra = c.expectedExtra + nic.testObject.transErr = c.transErr // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3d93a2cd0..74e70e283 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -23,11 +23,108 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// icmpv4DestinationUnreachableSockError is a general ICMPv4 Destination +// Unreachable error. +// +// +stateify savable +type icmpv4DestinationUnreachableSockError struct{} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP +} + +// Type implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Type() uint8 { + return uint8(header.ICMPv4DstUnreachable) +} + +// Info implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Info() uint32 { + return 0 +} + +var _ stack.TransportError = (*icmpv4DestinationHostUnreachableSockError)(nil) + +// icmpv4DestinationHostUnreachableSockError is an ICMPv4 Destination Host +// Unreachable error. +// +// It indicates that a packet was not able to reach the destination host. +// +// +stateify savable +type icmpv4DestinationHostUnreachableSockError struct { + icmpv4DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4DestinationHostUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv4HostUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv4DestinationHostUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationHostUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv4DestinationPortUnreachableSockError)(nil) + +// icmpv4DestinationPortUnreachableSockError is an ICMPv4 Destination Port +// Unreachable error. +// +// It indicates that a packet reached the destination host, but the transport +// protocol was not active on the destination port. +// +// +stateify savable +type icmpv4DestinationPortUnreachableSockError struct { + icmpv4DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4DestinationPortUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv4PortUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv4DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationPortUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv4FragmentationNeededSockError)(nil) + +// icmpv4FragmentationNeededSockError is an ICMPv4 Destination Unreachable error +// due to fragmentation being required but the packet was set to not be +// fragmented. +// +// It indicates that a link exists on the path to the destination with an MTU +// that is too small to carry the packet. +// +// +stateify savable +type icmpv4FragmentationNeededSockError struct { + icmpv4DestinationUnreachableSockError + + mtu uint32 +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4FragmentationNeededSockError) Code() uint8 { + return uint8(header.ICMPv4FragmentationNeeded) +} + +// Info implements tcpip.SockErrorCause. 
+func (e *icmpv4FragmentationNeededSockError) Info() uint32 { + return e.mtu +} + +// Kind implements stack.TransportError. +func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind { + return stack.PacketTooBigTransportError +} + // handleControl handles the case when an ICMP error packet contains the headers // of the original packet that caused the ICMP one to be sent. This information // is used to find out which transport endpoint must be notified about the ICMP // packet. We only expect the payload, not the enclosing ICMP packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -54,10 +151,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack return } - // Skip the ip header, then deliver control message. + // Skip the ip header, then deliver the error. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { @@ -222,19 +319,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4HostUnreachable: - e.handleControl(stack.ControlNoRoute, 0, pkt) - + e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, pkt) - + e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt) case header.ICMPv4FragmentationNeeded: networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) if err != nil { networkMTU = 0 } - e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) + e.handleControl(&icmpv4FragmentationNeededSockError{mtu: networkMTU}, pkt) } - case header.ICMPv4SrcQuench: received.srcQuench.Increment() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e146844c2..b2d626107 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -101,7 +101,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { // Use the same control type as an ICMPv4 destination host unreachable error // since the host is considered unreachable if we cannot resolve the link // address to the next hop. - e.handleControl(stack.ControlNoRoute, 0, pkt) + e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) } // NewEndpoint creates a new ipv4 endpoint. diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 12e5ead5e..dcfd93bab 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -23,11 +23,136 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// icmpv6DestinationUnreachableSockError is a general ICMPv6 Destination +// Unreachable error. +// +// +stateify savable +type icmpv6DestinationUnreachableSockError struct{} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv6DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP6 +} + +// Type implements tcpip.SockErrorCause. 
+func (*icmpv6DestinationUnreachableSockError) Type() uint8 { + return uint8(header.ICMPv6DstUnreachable) +} + +// Info implements tcpip.SockErrorCause. +func (*icmpv6DestinationUnreachableSockError) Info() uint32 { + return 0 +} + +var _ stack.TransportError = (*icmpv6DestinationNetworkUnreachableSockError)(nil) + +// icmpv6DestinationNetworkUnreachableSockError is an ICMPv6 Destination Network +// Unreachable error. +// +// It indicates that the destination network is unreachable. +// +// +stateify savable +type icmpv6DestinationNetworkUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationNetworkUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6NetworkUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv6DestinationNetworkUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationNetworkUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6DestinationPortUnreachableSockError)(nil) + +// icmpv6DestinationPortUnreachableSockError is an ICMPv6 Destination Port +// Unreachable error. +// +// It indicates that a packet reached the destination host, but the transport +// protocol was not active on the destination port. +// +// +stateify savable +type icmpv6DestinationPortUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationPortUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6PortUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv6DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationPortUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6DestinationAddressUnreachableSockError)(nil) + +// icmpv6DestinationAddressUnreachableSockError is an ICMPv6 Destination Address +// Unreachable error. +// +// It indicates that a packet was not able to reach the destination. +// +// +stateify savable +type icmpv6DestinationAddressUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationAddressUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6AddressUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv6DestinationAddressUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationHostUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6PacketTooBigSockError)(nil) + +// icmpv6PacketTooBigSockError is an ICMPv6 Packet Too Big error. +// +// It indicates that a link exists on the path to the destination with an MTU +// that is too small to carry the packet. +// +// +stateify savable +type icmpv6PacketTooBigSockError struct { + mtu uint32 +} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP6 +} + +// Type implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Type() uint8 { + return uint8(header.ICMPv6PacketTooBig) +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Code() uint8 { + return uint8(header.ICMPv6UnusedCode) +} + +// Info implements tcpip.SockErrorCause. +func (e *icmpv6PacketTooBigSockError) Info() uint32 { + return e.mtu +} + +// Kind implements stack.TransportError. 
+func (*icmpv6PacketTooBigSockError) Kind() stack.TransportErrorKind { + return stack.PacketTooBigTransportError +} + // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -67,8 +192,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack p = fragHdr.TransportProtocol() } - // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportError(src, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -175,7 +299,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { if err != nil { networkMTU = 0 } - e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) + e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() @@ -187,11 +311,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { case header.ICMPv6NetworkUnreachable: - e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationPortUnreachableSockError{}, pkt) } - case header.ICMPv6NeighborSolicit: received.neighborSolicit.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e56eb5796..c2e8c3ea7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -235,7 +235,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { }) pkt.NICID = e.nic.ID() pkt.NetworkProtocolNumber = ProtocolNumber - e.handleControl(stack.ControlAddressUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationAddressUnreachableSockError{}, pkt) } // onAddressAssignedLocked handles an address being assigned. diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 019d6a63c..1e00144a5 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -473,6 +473,48 @@ func (origin SockErrOrigin) IsICMPErr() bool { return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6 } +// SockErrorCause is the cause of a socket error. +type SockErrorCause interface { + // Origin is the source of the error. + Origin() SockErrOrigin + + // Type is the origin specific type of error. + Type() uint8 + + // Code is the origin and type specific error code. + Code() uint8 + + // Info is any extra information about the error. + Info() uint32 +} + +// LocalSockError is a socket error that originated from the local host. +// +// +stateify savable +type LocalSockError struct { + info uint32 +} + +// Origin implements SockErrorCause. 
+func (*LocalSockError) Origin() SockErrOrigin { + return SockExtErrorOriginLocal +} + +// Type implements SockErrorCause. +func (*LocalSockError) Type() uint8 { + return 0 +} + +// Code implements SockErrorCause. +func (*LocalSockError) Code() uint8 { + return 0 +} + +// Info implements SockErrorCause. +func (l *LocalSockError) Info() uint32 { + return l.info +} + // SockError represents a queue entry in the per-socket error queue. // // +stateify savable @@ -481,14 +523,8 @@ type SockError struct { // Err is the error caused by the errant packet. Err Error - // ErrOrigin indicates the error origin. - ErrOrigin SockErrOrigin - // ErrType is the type in the ICMP header. - ErrType uint8 - // ErrCode is the code in the ICMP header. - ErrCode uint8 - // ErrInfo is additional info about the error. - ErrInfo uint32 + // Cause is the detailed cause of the error. + Cause SockErrorCause // Payload is the errant packet's payload. Payload []byte @@ -540,12 +576,11 @@ func (so *SocketOptions) QueueErr(err *SockError) { // QueueLocalErr queues a local error onto the local queue. func (so *SocketOptions) QueueLocalErr(err Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { so.QueueErr(&SockError{ - Err: err, - ErrOrigin: SockExtErrorOriginLocal, - ErrInfo: info, - Payload: payload, - Dst: dst, - NetProto: net, + Err: err, + Cause: &LocalSockError{info: info}, + Payload: payload, + Dst: dst, + NetProto: net, }) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 693ea064a..41a489047 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -911,9 +911,8 @@ func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt } } -// DeliverTransportControlPacket delivers control packets to the appropriate -// transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { +// DeliverTransportError implements TransportDispatcher. +func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -935,7 +934,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { + if n.stack.demux.deliverError(n, net, trans, transErr, pkt, id) { return } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index e02f7190c..d589f798d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -49,31 +49,6 @@ type TransportEndpointID struct { RemoteAddress tcpip.Address } -// ControlType is the type of network control message. -type ControlType int - -// The following are the allowed values for ControlType values. -// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. -const ( - // ControlAddressUnreachable indicates that an IPv6 packet did not reach its - // destination as the destination address was unreachable. - // - // This maps to the ICMPv6 Destination Ureachable Code 3 error; see - // RFC 4443 section 3.1 for more details. 
- ControlAddressUnreachable ControlType = iota - ControlNetworkUnreachable - // ControlNoRoute indicates that an IPv4 packet did not reach its destination - // because the destination host was unreachable. - // - // This maps to the ICMPv4 Destination Ureachable Code 1 error; see - // RFC 791's Destination Unreachable Message section (page 4) for more - // details. - ControlNoRoute - ControlPacketTooBig - ControlPortUnreachable - ControlUnknown -) - // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { // LocalAddressBroadcast is true if the packet's local address is a broadcast @@ -81,6 +56,39 @@ type NetworkPacketInfo struct { LocalAddressBroadcast bool } +// TransportErrorKind enumerates error types that are handled by the transport +// layer. +type TransportErrorKind int + +const ( + // PacketTooBigTransportError indicates that a packet did not reach its + // destination because a link on the path to the destination had an MTU that + // was too small to carry the packet. + PacketTooBigTransportError TransportErrorKind = iota + + // DestinationHostUnreachableTransportError indicates that the destination + // host was unreachable. + DestinationHostUnreachableTransportError + + // DestinationPortUnreachableTransportError indicates that a packet reached + // the destination host, but the transport protocol was not active on the + // destination port. + DestinationPortUnreachableTransportError + + // DestinationNetworkUnreachableTransportError indicates that the destination + // network was unreachable. + DestinationNetworkUnreachableTransportError +) + +// TransportError is a marker interface for errors that may be handled by the +// transport layer. +type TransportError interface { + tcpip.SockErrorCause + + // Kind returns the type of the transport error. + Kind() TransportErrorKind +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { @@ -93,10 +101,10 @@ type TransportEndpoint interface { // HandlePacket takes ownership of the packet. HandlePacket(TransportEndpointID, *PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g. - // ICMP) packets arrive to this transport endpoint. - // HandleControlPacket takes ownership of pkt. - HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer) + // HandleError is called when the transport endpoint receives an error. + // + // HandleError takes ownership of the packet buffer. + HandleError(TransportError, *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -248,14 +256,11 @@ type TransportDispatcher interface { // DeliverTransportPacket takes ownership of the packet. DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition - // DeliverTransportControlPacket delivers control packets to the - // appropriate transport protocol endpoint. - // - // pkt.NetworkHeader must be set before calling - // DeliverTransportControlPacket. + // DeliverTransportError delivers an error to the appropriate transport + // endpoint. // - // DeliverTransportControlPacket takes ownership of pkt. 
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) + // DeliverTransportError takes ownership of the packet buffer. + DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 57ad412a1..a51d758d0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -458,6 +458,18 @@ type Stack struct { // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. receiveBufferSize ReceiveBufferSizeOption + + // tcpInvalidRateLimit is the maximal rate for sending duplicate + // acknowledgements in response to incoming TCP packets that are for an existing + // connection but that are invalid due to any of the following reasons: + // + // a) out-of-window sequence number. + // b) out-of-window acknowledgement number. + // c) PAWS check failure (when implemented). + // + // This is required to prevent potential ACK loops. + // Setting this to 0 will disable all rate limiting. + tcpInvalidRateLimit time.Duration } // UniqueID is an abstract generator of unique identifiers. @@ -668,6 +680,7 @@ func New(opts Options) *Stack { Default: DefaultBufferSize, Max: DefaultMaxBufferSize, }, + tcpInvalidRateLimit: defaultTCPInvalidRateLimit, } // Add specified network protocols. diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 8d9b20b7e..3066f4ffd 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -15,6 +15,8 @@ package stack import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" ) @@ -29,6 +31,10 @@ const ( // DefaultMaxBufferSize is the default maximum permitted size of a // send/receive buffer. DefaultMaxBufferSize = 4 << 20 // 4 MiB + + // defaultTCPInvalidRateLimit is the default value for + // stack.TCPInvalidRateLimit. + defaultTCPInvalidRateLimit = 500 * time.Millisecond ) // ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to @@ -39,6 +45,10 @@ type ReceiveBufferSizeOption struct { Max int } +// TCPInvalidRateLimitOption is used by stack.(Stack*).Option/SetOption to get/set +// stack.tcpInvalidRateLimit. +type TCPInvalidRateLimitOption time.Duration + // SetOption allows setting stack wide options. 
func (s *Stack) SetOption(option interface{}) tcpip.Error { switch v := option.(type) { @@ -74,6 +84,15 @@ func (s *Stack) SetOption(option interface{}) tcpip.Error { s.mu.Unlock() return nil + case TCPInvalidRateLimitOption: + if v < 0 { + return &tcpip.ErrInvalidOptionValue{} + } + s.mu.Lock() + s.tcpInvalidRateLimit = time.Duration(v) + s.mu.Unlock() + return nil + default: return &tcpip.ErrUnknownProtocolOption{} } @@ -94,6 +113,12 @@ func (s *Stack) Option(option interface{}) tcpip.Error { s.mu.RUnlock() return nil + case *TCPInvalidRateLimitOption: + s.mu.RLock() + *v = TCPInvalidRateLimitOption(s.tcpInvalidRateLimit) + s.mu.RUnlock() + return nil + default: return &tcpip.ErrUnknownProtocolOption{} } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 375cd3080..b641a4aaa 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -138,12 +138,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket( + f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), - stack.ControlPortUnreachable, 0, pkt) + // Nothing checks the error. + nil, /* transport error */ + pkt, + ) return } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 26eceb804..7d8d0851e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -182,9 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } -// handleControlPacket delivers a control packet to the transport endpoint -// identified by id. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { +// handleError delivers an error to the transport endpoint identified by id. +func (epsByNIC *endpointsByNIC) handleError(n *NIC, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -200,7 +199,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -596,9 +595,11 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb return foundRaw } -// deliverControlPacket attempts to deliver the given control packet. Returns -// true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { +// deliverError attempts to deliver the given error to the appropriate transport +// endpoint. +// +// Returns true if the error was delivered. 
+func (d *transportDemuxer) deliverError(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -611,7 +612,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return false } - ep.handleControlPacket(n, id, typ, extra, pkt) + ep.handleError(n, id, transErr, pkt) return true } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index cf5de747b..bebf4e6b5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -237,7 +237,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * f.acceptQueue = append(f.acceptQueue, ep) } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleError(stack.TransportError, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index b3a5d49d7..f2301a9e6 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -247,6 +247,14 @@ func TestPing(t *testing.T) { } } +type transportError struct { + origin tcpip.SockErrOrigin + typ uint8 + code uint8 + info uint32 + kind stack.TransportErrorKind +} + func TestTCPLinkResolutionFailure(t *testing.T) { const ( host1NICID = 1 @@ -259,6 +267,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr tcpip.Address expectedWriteErr tcpip.Error sockError tcpip.SockError + transErr transportError }{ { name: "IPv4 with resolvable remote", @@ -278,10 +287,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr: ipv4Addr3.AddressWithPrefix.Address, expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - ErrType: byte(header.ICMPv4DstUnreachable), - ErrCode: byte(header.ICMPv4HostUnreachable), - ErrOrigin: tcpip.SockExtErrorOriginICMP, + Err: &tcpip.ErrNoRoute{}, Dst: tcpip.FullAddress{ NIC: host1NICID, Addr: ipv4Addr3.AddressWithPrefix.Address, @@ -293,6 +299,12 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }, NetProto: ipv4.ProtocolNumber, }, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4HostUnreachable), + kind: stack.DestinationHostUnreachableTransportError, + }, }, { name: "IPv6 without resolvable remote", @@ -300,10 +312,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr: ipv6Addr3.AddressWithPrefix.Address, expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: &tcpip.ErrNoRoute{}, - ErrType: byte(header.ICMPv6DstUnreachable), - ErrCode: byte(header.ICMPv6AddressUnreachable), - ErrOrigin: tcpip.SockExtErrorOriginICMP6, + Err: &tcpip.ErrNoRoute{}, Dst: tcpip.FullAddress{ NIC: host1NICID, Addr: ipv6Addr3.AddressWithPrefix.Address, @@ -315,6 +324,12 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }, NetProto: ipv6.ProtocolNumber, }, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6DstUnreachable), + code: uint8(header.ICMPv6AddressUnreachable), + kind: stack.DestinationHostUnreachableTransportError, + }, }, } @@ -393,9 +408,12 @@ func TestTCPLinkResolutionFailure(t *testing.T) { // 
are pre defined so we can simply compare pointers. return a == b }), - // Ignore the payload since we do not know the TCP seq/ack numbers. checker.IgnoreCmpPath( + // Ignore the payload since we do not know the TCP seq/ack numbers. "Payload", + // Ignore the cause since we will compare its properties separately + // since the concrete type of the cause is unknown. + "Cause", ), } @@ -407,6 +425,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { t.Errorf("socket error mismatch (-want +got):\n%s", diff) } + + transErr, ok := sockErr.Cause.(stack.TransportError) + if !ok { + t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause) + } + if diff := cmp.Diff( + test.transErr, + transportError{ + origin: transErr.Origin(), + typ: transErr.Type(), + code: transErr.Code(), + info: transErr.Info(), + kind: transErr.Kind(), + }, + cmp.AllowUnexported(transportError{}), + ); diff != "" { + t.Errorf("socket error mismatch (-want +got):\n%s", diff) + } }) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 3cf05520d..f5e1a6e45 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -778,9 +778,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { -} +// HandleError implements stack.TransportEndpoint. +func (*endpoint) HandleError(stack.TransportError, *stack.PacketBuffer) {} // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't // expose internal socket state. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4695b66d6..34a631b53 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -333,7 +333,9 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // number and "After sending the acknowledgment, drop the unacceptable // segment and return." if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { - h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + if h.ep.allowOutOfWindowAck() { + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + } return nil } @@ -1185,8 +1187,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { // endpoint MUST terminate its connection. The local TCP endpoint // should then rely on SYN retransmission from the remote end to // re-establish the connection. - - e.snd.sendAck() + e.snd.maybeSendOutOfWindowAck(s) } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 7b1f5e763..1975f1a44 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -178,8 +178,8 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int return int(cwnd) } -// HandleNDupAcks implements congestionControl.HandleNDupAcks. -func (c *cubicState) HandleNDupAcks() { +// HandleLossDetected implements congestionControl.HandleLossDetected. 
+func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ c.t = time.Now() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6e4e26c39..4e5a6089f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -688,6 +688,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // lastOutOfWindowAckTime is the time at which the an ACK was sent in response + // to an out of window segment being received by this endpoint. + lastOutOfWindowAckTime time.Time `state:".(unixTime)"` } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -2683,7 +2687,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } -func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -2692,11 +2696,8 @@ func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extr // Update the error queue if IP_RECVERR is enabled. if e.SocketOptions().GetRecvError() { e.SocketOptions().QueueErr(&tcpip.SockError{ - Err: err, - ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), - ErrType: errType, - ErrCode: errCode, - ErrInfo: extra, + Err: err, + Cause: transErr, // Linux passes the payload with the TCP header. We don't know if the TCP // header even exists, it may not for fragmented packets. Payload: pkt.Data.ToView(), @@ -2718,27 +2719,26 @@ func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extr e.notifyProtocolGoroutine(notifyError) } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { - switch typ { - case stack.ControlPacketTooBig: +// HandleError implements stack.TransportEndpoint. +func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { + handlePacketTooBig := func(mtu uint32) { e.sndBufMu.Lock() e.packetTooBigCount++ - if v := int(extra); v < e.sndMTU { + if v := int(mtu); v < e.sndMTU { e.sndMTU = v } e.sndBufMu.Unlock() - e.notifyProtocolGoroutine(notifyMTUChanged) + } - case stack.ControlNoRoute: - e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) - - case stack.ControlAddressUnreachable: - e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) - - case stack.ControlNetworkUnreachable: - e.onICMPError(&tcpip.ErrNetworkUnreachable{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) + // TODO(gvisor.dev/issues/5270): Handle all transport errors. + switch transErr.Kind() { + case stack.PacketTooBigTransportError: + handlePacketTooBig(transErr.Info()) + case stack.DestinationHostUnreachableTransportError: + e.onICMPError(&tcpip.ErrNoRoute{}, transErr, pkt) + case stack.DestinationNetworkUnreachableTransportError: + e.onICMPError(&tcpip.ErrNetworkUnreachable{}, transErr, pkt) } } @@ -3129,3 +3129,19 @@ func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { Max: ss.Max, } } + +// allowOutOfWindowAck returns true if an out-of-window ACK can be sent now. 
+func (e *endpoint) allowOutOfWindowAck() bool { + var limit stack.TCPInvalidRateLimitOption + if err := e.stack.Option(&limit); err != nil { + panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) + } + + now := time.Now() + if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { + return false + } + + e.lastOutOfWindowAckTime = now + return true +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index c21dbc682..e4368026f 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -308,6 +308,16 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { e.recentTSTime = time.Unix(unix.second, unix.nano) } +// saveLastOutOfWindowAckTime is invoked by stateify. +func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { + return unixTime{e.lastOutOfWindowAckTime.Unix(), e.lastOutOfWindowAckTime.UnixNano()} +} + +// loadLastOutOfWindowAckTime is invoked by stateify. +func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { + e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) +} + // saveMeasureTime is invoked by stateify. func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index d85cb405a..e862f159e 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -301,7 +301,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // Step 2. Either the original packet or the retransmission (in the // form of a probe) was lost. Invoke a congestion control response // equivalent to fast recovery. - s.cc.HandleNDupAcks() + s.cc.HandleLossDetected() s.enterRecovery() s.leaveRecovery() } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 7a7c402c4..a5c82b8fa 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -385,7 +385,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // fails, we ignore the packet: // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 if r.ep.snd.sndNxt.LessThan(s.ackNumber) { - r.ep.snd.sendAck() + r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } @@ -454,7 +454,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // send an ACK and stop further processing of the segment. // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { - r.ep.snd.sendAck() + r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go index f83ebc717..ff39780a5 100644 --- a/pkg/tcpip/transport/tcp/reno.go +++ b/pkg/tcpip/transport/tcp/reno.go @@ -79,10 +79,10 @@ func (r *renoState) Update(packetsAcked int) { r.updateCongestionAvoidance(packetsAcked) } -// HandleNDupAcks implements congestionControl.HandleNDupAcks. -func (r *renoState) HandleNDupAcks() { - // A retransmit was triggered due to nDupAckThreshold - // being hit. Reduce our slow start threshold. +// HandleLossDetected implements congestionControl.HandleLossDetected. +func (r *renoState) HandleLossDetected() { + // A retransmit was triggered due to nDupAckThreshold or when RACK + // detected loss. Reduce our slow start threshold. 
r.reduceSlowStartThreshold() } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index dfc8fd248..463a259b7 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -51,9 +51,10 @@ const ( // congestionControl is an interface that must be implemented by any supported // congestion control algorithm. type congestionControl interface { - // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold - // just before entering fast retransmit. - HandleNDupAcks() + // HandleLossDetected is invoked when the loss is detected by RACK or + // sender.dupAckCount >= nDupAckThreshold just before entering fast + // retransmit. + HandleLossDetected() // HandleRTOExpired is invoked when the retransmit timer expires. HandleRTOExpired() @@ -1152,7 +1153,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { s.dupAckCount = 0 return false } - s.cc.HandleNDupAcks() + s.cc.HandleLossDetected() s.enterRecovery() s.dupAckCount = 0 return true @@ -1548,3 +1549,13 @@ func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } + +// maybeSendOutOfWindowAck sends an ACK if we are not being rate limited +// currently. +func (s *sender) maybeSendOutOfWindowAck(seg *segment) { + // Data packets are unlikely to be part of an ACK loop. So always send + // an ACK for a packet w/ data. + if seg.payloadSize() > 0 || s.ep.allowOutOfWindowAck() { + s.sendAck() + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index da2730e27..cd3c4a027 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6302,6 +6302,13 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Enable Auto-tuning. stk := c.Stack() + // Disable out of window rate limiting for this test by setting it to 0 as we + // use out of window ACKs to measure the advertised window. + var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption + if err := stk.SetOption(tcpInvalidRateLimit); err != nil { + t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err) + } + const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 31a5ddce9..afd8f4d39 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1322,7 +1322,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -1338,12 +1338,9 @@ func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extr } e.SocketOptions().QueueErr(&tcpip.SockError{ - Err: err, - ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), - ErrType: errType, - ErrCode: errCode, - ErrInfo: extra, - Payload: payload, + Err: err, + Cause: transErr, + Payload: payload, Dst: tcpip.FullAddress{ NIC: pkt.NICID, Addr: e.ID.RemoteAddress, @@ -1362,24 +1359,13 @@ func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extr e.waiterQueue.Notify(waiter.EventErr) } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
-func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { - if typ == stack.ControlPortUnreachable { +// HandleError implements stack.TransportEndpoint. +func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { + // TODO(gvisor.dev/issues/5270): Handle all transport errors. + switch transErr.Kind() { + case stack.DestinationPortUnreachableTransportError: if e.EndpointState() == StateConnected { - var errType byte - var errCode byte - switch pkt.NetworkProtocolNumber { - case header.IPv4ProtocolNumber: - errType = byte(header.ICMPv4DstUnreachable) - errCode = byte(header.ICMPv4PortUnreachable) - case header.IPv6ProtocolNumber: - errType = byte(header.ICMPv6DstUnreachable) - errCode = byte(header.ICMPv6PortUnreachable) - default: - panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) - } - e.onICMPError(&tcpip.ErrConnectionRefused{}, errType, errCode, extra, pkt) - return + e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt) } } } |
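The new ICMPv4/ICMPv6 sock error types introduced in pkg/tcpip/network/ipv4/icmp.go and ipv6/icmp.go all follow the same shape: a zero-sized base type supplies the shared Origin/Type/Info accessors, and each concrete error embeds it, overriding only Code (plus Info where the error carries data, such as the MTU for fragmentation-needed) and adding the transport-layer Kind. The following is a minimal sketch of that pattern; the names exampleDstUnreachableBase and examplePortUnreachableError are illustrative and not part of this change.

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// exampleDstUnreachableBase carries the accessors shared by every ICMPv4
// destination-unreachable variant.
type exampleDstUnreachableBase struct{}

func (*exampleDstUnreachableBase) Origin() tcpip.SockErrOrigin { return tcpip.SockExtErrorOriginICMP }
func (*exampleDstUnreachableBase) Type() uint8                 { return uint8(header.ICMPv4DstUnreachable) }
func (*exampleDstUnreachableBase) Info() uint32                { return 0 }

// examplePortUnreachableError embeds the base and contributes only the
// code- and kind-specific pieces.
type examplePortUnreachableError struct {
	exampleDstUnreachableBase
}

func (*examplePortUnreachableError) Code() uint8 { return uint8(header.ICMPv4PortUnreachable) }
func (*examplePortUnreachableError) Kind() stack.TransportErrorKind {
	return stack.DestinationPortUnreachableTransportError
}

// Compile-time checks: the concrete type satisfies both the socket-facing
// cause interface and the transport-facing error interface.
var (
	_ tcpip.SockErrorCause = (*examplePortUnreachableError)(nil)
	_ stack.TransportError = (*examplePortUnreachableError)(nil)
)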
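The updated tests in ip_test.go and link_resolution_test.go no longer compare a (ControlType, extra) pair; they flatten the received stack.TransportError into a small unexported struct and diff it with go-cmp. A hedged sketch of that comparison as a reusable helper is below; checkTransportError is an illustrative name, not a helper added by this change.

package example

import (
	"testing"

	"github.com/google/go-cmp/cmp"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// transportError mirrors the unexported struct the tests use to snapshot a
// TransportError's properties for comparison.
type transportError struct {
	origin tcpip.SockErrOrigin
	typ    uint8
	code   uint8
	info   uint32
	kind   stack.TransportErrorKind
}

// checkTransportError flattens got through its accessors and diffs it
// against want, allowing the unexported fields.
func checkTransportError(t *testing.T, want transportError, got stack.TransportError) {
	t.Helper()
	if diff := cmp.Diff(want, transportError{
		origin: got.Origin(),
		typ:    got.Type(),
		code:   got.Code(),
		info:   got.Info(),
		kind:   got.Kind(),
	}, cmp.AllowUnexported(transportError{})); diff != "" {
		t.Errorf("transport error mismatch (-want +got):\n%s", diff)
	}
}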
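The new stack-wide TCPInvalidRateLimitOption is read and written through the existing Stack.Option/SetOption plumbing shown above; negative values are rejected with ErrInvalidOptionValue, the default is 500ms, and a zero value disables the limit (the receive-buffer auto-tuning test relies on this). A small usage sketch follows; the wrapper function names are illustrative.

package example

import (
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// setInvalidRateLimit configures the minimum interval between ACKs sent in
// response to invalid (e.g. out-of-window) TCP segments.
func setInvalidRateLimit(s *stack.Stack, d time.Duration) tcpip.Error {
	return s.SetOption(stack.TCPInvalidRateLimitOption(d))
}

// invalidRateLimit reads the current limit back from the stack.
func invalidRateLimit(s *stack.Stack) (time.Duration, tcpip.Error) {
	var opt stack.TCPInvalidRateLimitOption
	if err := s.Option(&opt); err != nil {
		return 0, err
	}
	return time.Duration(opt), nil
}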
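The endpoint-side gate in allowOutOfWindowAck/maybeSendOutOfWindowAck reduces to a per-endpoint minimum-interval check: segments carrying data are always acknowledged, otherwise at most one ACK is sent per tcpInvalidRateLimit window, which breaks potential ACK loops. A standalone sketch of that check, assuming an illustrative ackRateLimiter type that is not part of the gvisor API:

package example

import "time"

// ackRateLimiter allows at most one out-of-window ACK per interval. A zero
// interval disables rate limiting, matching TCPInvalidRateLimitOption(0).
type ackRateLimiter struct {
	interval time.Duration
	last     time.Time
}

// allow reports whether an ACK may be sent now and, if so, records the time.
func (l *ackRateLimiter) allow(now time.Time, payloadSize int) bool {
	// Segments with data are unlikely to be part of an ACK loop, so they are
	// always acknowledged.
	if payloadSize > 0 {
		return true
	}
	if now.Sub(l.last) < l.interval {
		return false
	}
	l.last = now
	return true
}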