diff options
526 files changed, 17598 insertions, 8351 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml index 3bc5041c0..c1b478dc3 100644 --- a/.buildkite/pipeline.yaml +++ b/.buildkite/pipeline.yaml @@ -55,6 +55,9 @@ steps: # Basic unit tests. - <<: *common + label: ":golang: Nogo tests" + command: make nogo-tests + - <<: *common label: ":test_tube: Unit tests" command: make unit-tests - <<: *common @@ -69,9 +72,6 @@ steps: # Integration tests. - <<: *common - label: ":parachute: FUSE tests" - command: make fuse-tests - - <<: *common label: ":docker: Docker tests" command: make docker-tests - <<: *common @@ -90,6 +90,9 @@ steps: label: ":person_in_lotus_position: KVM tests" command: make kvm-tests - <<: *common + label: ":weight_lifter: Fsstress test" + command: make fsstress-test + - <<: *common label: ":docker: Containerd 1.3.9 tests" command: make containerd-test-1.3.9 - <<: *common diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index b572dc94f..000000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,29 +0,0 @@ -# This workflow builds the source code, extracts nogo annotations and -# posts them to GitHub, if applicable. This leverages the fact that the -# workflow token has appropriate permissions to do so, and attempts to -# leverage the GitHub workflow caches. -name: "Build" -"on": - push: - branches: - - master - pull_request: - branches: - - master - - "feature/**" - -jobs: - default: - runs-on: ubuntu-latest - steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@0.7.0 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v2 - - run: make - - run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..." 
- - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ -path=bazel-out/ nogo" - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - GITHUB_REPOSITORY: ${{ github.repository }} diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 3a4aa22e2..a9e0a4717 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -15,7 +15,7 @@ jobs: stale-issue-label: 'stale' stale-pr-label: 'stale' exempt-issue-labels: 'exported, type: bug, type: cleanup, type: enhancement, type: process, type: proposal, type: question' - exempt-pr-labels: 'ready to pull' + exempt-pr-labels: 'ready to pull, exported' stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.' stale-pr-message: 'This pull request is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.' days-before-stale: 90 @@ -144,14 +144,10 @@ dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo. @$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile) @$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2) @$(call configure_noreload,$(RUNTIME)-vfs2-fuse-d,--net-raw --debug --strace --log-packets --vfs2 --fuse) + @$(call configure_noreload,$(RUNTIME)-vfs2-cgroup-d,--net-raw --debug --strace --log-packets --vfs2 --cgroupfs) @$(call reload_docker) .PHONY: dev -nogo: ## Surfaces all nogo findings. - @$(call build,--build_tag_filters nogo //...) - @$(call run,//tools/github $(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo) -.PHONY: nogo - ## ## Canonical build and test targets. ## @@ -179,12 +175,12 @@ smoke-tests: ## Runs a simple smoke test after build runsc. 
@$(call run,//runsc,--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true) .PHONY: smoke-tests -fuse-tests: - @$(call test,--test_tag_filters=fuse $(PARTITIONS) test/fuse/...) -.PHONY: fuse-tests +nogo-tests: + @$(call test,--build_tag_filters=nogo --test_tag_filters=nogo //:all pkg/... tools/...) +.PHONY: nogo-tests unit-tests: ## Local package unit tests in pkg/..., tools/.., etc. - @$(call test,//:all pkg/... tools/...) + @$(call test,--build_tag_filters=-nogo --test_tag_filters=-nogo //:all pkg/... tools/...) .PHONY: unit-tests runsc-tests: ## Run all tests in runsc/... @@ -192,7 +188,7 @@ runsc-tests: ## Run all tests in runsc/... .PHONY: runsc-tests tests: ## Runs all unit tests and syscall tests. -tests: unit-tests runsc-tests syscall-tests +tests: unit-tests nogo-tests runsc-tests syscall-tests .PHONY: tests integration-tests: ## Run all standard integration tests. @@ -204,6 +200,9 @@ network-tests: ## Run all networking integration tests. network-tests: iptables-tests packetdrill-tests packetimpact-tests .PHONY: network-tests +# The set of system call targets. +SYSCALL_TARGETS := test/syscalls/... test/fuse/... + syscall-%-tests: @$(call test,--test_tag_filters=runsc_$* $(PARTITIONS) test/syscalls/...) @@ -212,7 +211,8 @@ syscall-native-tests: .PHONY: syscall-native-tests syscall-tests: ## Run all system call tests. - @$(call test,$(PARTITIONS) test/syscalls/...) + @$(call test,$(PARTITIONS) $(SYSCALL_TARGETS)) +.PHONY: syscall-tests %-runtime-tests: load-runtimes_% $(RUNTIME_BIN) @$(call install_runtime,$(RUNTIME),) # Ensure flags are cleared. @@ -340,7 +340,8 @@ BENCHMARKS_FILTER := . 
BENCHMARKS_OPTIONS := -test.benchtime=30s BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) $(BENCHMARKS_OPTIONS) BENCHMARKS_PROFILE := -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex -BENCH_RUNTIME_ARGS ?= --vfs2 +BENCH_VFS := --vfs2 +BENCH_RUNTIME_ARGS ?= init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema. @$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)) @@ -361,13 +362,14 @@ run_benchmark = \ benchmark-platforms: load-benchmarks $(RUNTIME_BIN) ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS. @$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \ - $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \ - ) true + $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS) --vfs2) && \ + $(call run_benchmark,$(PLATFORM)_vfs1,--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \ + ) true @$(call run_benchmark,runc) .PHONY: benchmark-platforms run-benchmark: load-benchmarks $(RUNTIME_BIN) ## Runs single benchmark and optionally sends data to BigQuery. 
- @$(call run_benchmark,$(RUNTIME),$(BENCH_RUNTIME_ARGS)) + @$(call run_benchmark,$(RUNTIME)$(BENCH_VFS),$(BENCH_RUNTIME_ARGS) $(BENCH_VFS)) .PHONY: run-benchmark ## diff --git a/debian/BUILD b/debian/BUILD index 64aa2369a..32cc209bf 100644 --- a/debian/BUILD +++ b/debian/BUILD @@ -29,6 +29,9 @@ pkg_deb( arm64 = "arm64", ), changes = "runsc.changes", + conffiles = [ + "/etc/containerd/runsc.toml", + ], data = ":debian-data", deb = "runsc.deb", # Note that the description_file will be flatten (all newlines removed), diff --git a/g3doc/user_guide/containerd/configuration.md b/g3doc/user_guide/containerd/configuration.md index 011af3b10..a214fb0c7 100644 --- a/g3doc/user_guide/containerd/configuration.md +++ b/g3doc/user_guide/containerd/configuration.md @@ -14,6 +14,7 @@ cat <<EOF | sudo tee /etc/containerd/runsc.toml option = "value" [runsc_config] flag = "value" +EOF ``` The set of options that can be configured can be found in @@ -32,10 +33,12 @@ configuration. Here is an example: ```shell cat <<EOF | sudo tee /etc/containerd/config.toml -disabled_plugins = ["restart"] -[plugins.cri.containerd.runtimes.runsc] +version = 2 +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runc] + runtime_type = "io.containerd.runc.v2" +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc] runtime_type = "io.containerd.runsc.v1" -[plugins.cri.containerd.runtimes.runsc.options] +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc.options] TypeUrl = "io.containerd.runsc.v1.options" ConfigPath = "/etc/containerd/runsc.toml" EOF @@ -56,14 +59,16 @@ a containerd configuration file that enables both options: ```shell cat <<EOF | sudo tee /etc/containerd/config.toml -disabled_plugins = ["restart"] +version = 2 [debug] level = "debug" -[plugins.linux] +[plugins."io.containerd.runtime.v1.linux"] shim_debug = true -[plugins.cri.containerd.runtimes.runsc] +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runc] + runtime_type = "io.containerd.runc.v2" 
+[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc] runtime_type = "io.containerd.runsc.v1" -[plugins.cri.containerd.runtimes.runsc.options] +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc.options] TypeUrl = "io.containerd.runsc.v1.options" ConfigPath = "/etc/containerd/runsc.toml" EOF @@ -93,4 +98,5 @@ log_level = "debug" [runsc_config] debug = "true" debug-log = "/var/log/runsc/%ID%/gvisor.%COMMAND%.log" +EOF ``` diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md index 02e82eb32..c742f225c 100644 --- a/g3doc/user_guide/containerd/quick_start.md +++ b/g3doc/user_guide/containerd/quick_start.md @@ -21,10 +21,12 @@ Update `/etc/containerd/config.toml`. Make sure `containerd-shim-runsc-v1` is in ```shell cat <<EOF | sudo tee /etc/containerd/config.toml -disabled_plugins = ["restart"] -[plugins.linux] +version = 2 +[plugins."io.containerd.runtime.v1.linux"] shim_debug = true -[plugins.cri.containerd.runtimes.runsc] +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runc] + runtime_type = "io.containerd.runc.v2" +[plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc] runtime_type = "io.containerd.runsc.v1" EOF ``` @@ -55,8 +55,6 @@ global: # Same story for underscores. - "should not use ALL_CAPS in Go names" - "should not use underscores in Go names" - # TODO(b/179817829): Upgrade to flock to v0.8.0. - - "flock.NewFlock is deprecated: Use New instead" exclude: # Generated: exempt all. - pkg/shim/runtimeoptions/runtimeoptions_cri.go @@ -91,6 +89,7 @@ analyzers: - pkg/sentry/fsimpl/gofer/filesystem.go # unsupported usage. - pkg/sentry/fsimpl/gofer/gofer.go # unsupported usage. - pkg/sentry/fsimpl/gofer/regular_file.go # unsupported usage. + - pkg/sentry/fsimpl/gofer/revalidate.go # unsupported usage. - pkg/sentry/fsimpl/gofer/special_file.go # unsupported usage. - pkg/sentry/fsimpl/gofer/symlink.go # unsupported usage. - pkg/sentry/fsimpl/overlay/copy_up.go # unsupported usage. 
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index ecaeb11ac..064a54547 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -76,7 +76,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/abi", - "//pkg/binary", "//pkg/bits", "//pkg/marshal", "//pkg/marshal/primitive", @@ -88,7 +87,4 @@ go_test( size = "small", srcs = ["netfilter_test.go"], library = ":linux", - deps = [ - "//pkg/binary", - ], ) diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go index 7c9a02f20..c5713541f 100644 --- a/pkg/abi/linux/elf.go +++ b/pkg/abi/linux/elf.go @@ -106,3 +106,53 @@ const ( // NT_ARM_TLS is for ARM TLS register. NT_ARM_TLS = 0x401 ) + +// ElfHeader64 is the ELF64 file header. +// +// +marshal +type ElfHeader64 struct { + Ident [16]byte // File identification. + Type uint16 // File type. + Machine uint16 // Machine architecture. + Version uint32 // ELF format version. + Entry uint64 // Entry point. + Phoff uint64 // Program header file offset. + Shoff uint64 // Section header file offset. + Flags uint32 // Architecture-specific flags. + Ehsize uint16 // Size of ELF header in bytes. + Phentsize uint16 // Size of program header entry. + Phnum uint16 // Number of program header entries. + Shentsize uint16 // Size of section header entry. + Shnum uint16 // Number of section header entries. + Shstrndx uint16 // Section name strings section. +} + +// ElfSection64 is the ELF64 Section header. +// +// +marshal +type ElfSection64 struct { + Name uint32 // Section name (index into the section header string table). + Type uint32 // Section type. + Flags uint64 // Section flags. + Addr uint64 // Address in memory image. + Off uint64 // Offset in file. + Size uint64 // Size in bytes. + Link uint32 // Index of a related section. + Info uint32 // Depends on section type. + Addralign uint64 // Alignment in bytes. + Entsize uint64 // Size of each entry in section. +} + +// ElfProg64 is the ELF64 Program header. 
+// +// +marshal +type ElfProg64 struct { + Type uint32 // Entry type. + Flags uint32 // Access permission flags. + Off uint64 // File offset of contents. + Vaddr uint64 // Virtual address in memory image. + Paddr uint64 // Physical address (not used). + Filesz uint64 // Size of contents in file. + Memsz uint64 // Size of contents in memory. + Align uint64 // Alignment in memory and file. +} diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go index 1121a1a92..67706f5aa 100644 --- a/pkg/abi/linux/epoll.go +++ b/pkg/abi/linux/epoll.go @@ -14,10 +14,6 @@ package linux -import ( - "gvisor.dev/gvisor/pkg/binary" -) - // Event masks. const ( EPOLLIN = 0x1 @@ -59,4 +55,4 @@ const ( ) // SizeOfEpollEvent is the size of EpollEvent struct. -var SizeOfEpollEvent = int(binary.Size(EpollEvent{})) +var SizeOfEpollEvent = (*EpollEvent)(nil).SizeBytes() diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index e11ca2d62..1e23850a9 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -19,7 +19,6 @@ import ( "strings" "gvisor.dev/gvisor/pkg/abi" - "gvisor.dev/gvisor/pkg/binary" ) // Constants for open(2). @@ -201,7 +200,7 @@ const ( ) // SizeOfStat is the size of a Stat struct. -var SizeOfStat = binary.Size(Stat{}) +var SizeOfStat = (*Stat)(nil).SizeBytes() // Flags for statx. const ( @@ -268,7 +267,7 @@ type Statx struct { } // SizeOfStatx is the size of a Statx struct. -var SizeOfStatx = binary.Size(Statx{}) +var SizeOfStatx = (*Statx)(nil).SizeBytes() // FileMode represents a mode_t. type FileMode uint16 diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index 0d921ed6f..cad24fcc7 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -19,8 +19,10 @@ package linux // See linux/magic.h. 
const ( ANON_INODE_FS_MAGIC = 0x09041934 + CGROUP_SUPER_MAGIC = 0x27e0eb DEVPTS_SUPER_MAGIC = 0x00001cd1 EXT_SUPER_MAGIC = 0xef53 + FUSE_SUPER_MAGIC = 0x65735546 OVERLAYFS_SUPER_MAGIC = 0x794c7630 PIPEFS_MAGIC = 0x50495045 PROC_SUPER_MAGIC = 0x9fa0 @@ -29,7 +31,6 @@ const ( SYSFS_MAGIC = 0x62656572 TMPFS_MAGIC = 0x01021994 V9FS_MAGIC = 0x01021997 - FUSE_SUPER_MAGIC = 0x65735546 ) // Filesystem path limits, from uapi/linux/limits.h. diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go index 0faf015c7..51a39704b 100644 --- a/pkg/abi/linux/netdevice.go +++ b/pkg/abi/linux/netdevice.go @@ -14,8 +14,6 @@ package linux -import "gvisor.dev/gvisor/pkg/binary" - const ( // IFNAMSIZ is the size of the name field for IFReq. IFNAMSIZ = 16 @@ -66,7 +64,7 @@ func (ifr *IFReq) SetName(name string) { } // SizeOfIFReq is the binary size of an IFReq struct (40 bytes). -var SizeOfIFReq = binary.Size(IFReq{}) +var SizeOfIFReq = (*IFReq)(nil).SizeBytes() // IFMap contains interface hardware parameters. type IFMap struct { diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 378f1baf3..3fd05483a 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -145,13 +145,13 @@ func (ke *KernelIPTEntry) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { - ke.Entry.MarshalBytes(dst) + ke.Entry.MarshalUnsafe(dst) ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { - ke.Entry.UnmarshalBytes(src) + ke.Entry.UnmarshalUnsafe(src) ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } @@ -245,6 +245,8 @@ const SizeOfXTCounters = 16 // include/uapi/linux/netfilter/x_tables.h. That struct contains a union // exposing different data to the user and kernel, but this struct holds only // the user data. 
+// +// +marshal type XTEntryMatch struct { MatchSize uint16 Name ExtensionName @@ -284,6 +286,8 @@ const SizeOfXTGetRevision = 30 // include/uapi/linux/netfilter/x_tables.h. That struct contains a union // exposing different data to the user and kernel, but this struct holds only // the user data. +// +// +marshal type XTEntryTarget struct { TargetSize uint16 Name ExtensionName @@ -306,6 +310,8 @@ type KernelXTEntryTarget struct { // XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE, // RETURN, or jump. It corresponds to struct xt_standard_target in // include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTStandardTarget struct { Target XTEntryTarget // A positive verdict indicates a jump, and is the offset from the @@ -322,6 +328,8 @@ const SizeOfXTStandardTarget = 40 // beginning of user-defined chains by putting the name of the chain in // ErrorName. It corresponds to struct xt_error_target in // include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTErrorTarget struct { Target XTEntryTarget Name ErrorName @@ -349,6 +357,8 @@ const ( // NfNATIPV4Range corresponds to struct nf_nat_ipv4_range // in include/uapi/linux/netfilter/nf_nat.h. The fields are in // network byte order. +// +// +marshal type NfNATIPV4Range struct { Flags uint32 MinIP [4]byte @@ -359,6 +369,8 @@ type NfNATIPV4Range struct { // NfNATIPV4MultiRangeCompat corresponds to struct // nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h. +// +// +marshal type NfNATIPV4MultiRangeCompat struct { RangeSize uint32 RangeIPV4 NfNATIPV4Range @@ -366,6 +378,8 @@ type NfNATIPV4MultiRangeCompat struct { // XTRedirectTarget triggers a redirect when reached. // Adding 4 bytes of padding to make the struct 8 byte aligned. +// +// +marshal type XTRedirectTarget struct { Target XTEntryTarget NfRange NfNATIPV4MultiRangeCompat @@ -375,6 +389,19 @@ type XTRedirectTarget struct { // SizeOfXTRedirectTarget is the size of an XTRedirectTarget. 
const SizeOfXTRedirectTarget = 56 +// XTSNATTarget triggers Source NAT when reached. +// Adding 4 bytes of padding to make the struct 8 byte aligned. +// +// +marshal +type XTSNATTarget struct { + Target XTEntryTarget + NfRange NfNATIPV4MultiRangeCompat + _ [4]byte +} + +// SizeOfXTSNATTarget is the size of an XTSNATTarget. +const SizeOfXTSNATTarget = 56 + // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. // @@ -429,7 +456,7 @@ func (ke *KernelIPTGetEntries) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { - ke.IPTGetEntries.MarshalBytes(dst) + ke.IPTGetEntries.MarshalUnsafe(dst) marshalledUntil := ke.IPTGetEntries.SizeBytes() for i := range ke.Entrytable { ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) @@ -439,7 +466,7 @@ func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { - ke.IPTGetEntries.UnmarshalBytes(src) + ke.IPTGetEntries.UnmarshalUnsafe(src) unmarshalledUntil := ke.IPTGetEntries.SizeBytes() for i := range ke.Entrytable { ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) @@ -452,6 +479,8 @@ var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTReplace struct { Name TableName ValidHooks uint32 @@ -491,6 +520,8 @@ func (tn TableName) String() string { // ErrorName holds the name of a netfilter error. These can also hold // user-defined chains. +// +// +marshal type ErrorName [XT_FUNCTION_MAXNAMELEN]byte // String implements fmt.Stringer. @@ -509,6 +540,8 @@ func goString(cstring []byte) string { // XTTCP holds data for matching TCP packets. 
It corresponds to struct xt_tcp // in include/uapi/linux/netfilter/xt_tcpudp.h. +// +// +marshal type XTTCP struct { // SourcePortStart specifies the inclusive start of the range of source // ports to which the matcher applies. @@ -562,6 +595,8 @@ const ( // XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp // in include/uapi/linux/netfilter/xt_tcpudp.h. +// +// +marshal type XTUDP struct { // SourcePortStart is the inclusive start of the range of source ports // to which the matcher applies. @@ -602,6 +637,8 @@ const ( // IPTOwnerInfo holds data for matching packets with owner. It corresponds // to struct ipt_owner_info in libxt_owner.c of iptables binary. +// +// +marshal type IPTOwnerInfo struct { // UID is user id which created the packet. UID uint32 @@ -623,7 +660,7 @@ type IPTOwnerInfo struct { Match uint8 // Invert flips the meaning of Match field. - Invert uint8 + Invert uint8 `marshal:"unaligned"` } // SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo. diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index b953e62dc..b088b207c 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -86,7 +86,7 @@ func (ke *KernelIP6TGetEntries) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) { - ke.IPTGetEntries.MarshalBytes(dst) + ke.IPTGetEntries.MarshalUnsafe(dst) marshalledUntil := ke.IPTGetEntries.SizeBytes() for i := range ke.Entrytable { ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) @@ -96,7 +96,7 @@ func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) { // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. 
func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) { - ke.IPTGetEntries.UnmarshalBytes(src) + ke.IPTGetEntries.UnmarshalUnsafe(src) unmarshalledUntil := ke.IPTGetEntries.SizeBytes() for i := range ke.Entrytable { ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) @@ -149,8 +149,8 @@ type IP6TEntry struct { const SizeOfIP6TEntry = 168 // KernelIP6TEntry is identical to IP6TEntry, but includes the Elems field. -// KernelIP6TEntry itself is not Marshallable but it implements some methods of -// marshal.Marshallable that help in other implementations of Marshallable. +// +// +marshal dynamic type KernelIP6TEntry struct { Entry IP6TEntry @@ -168,13 +168,13 @@ func (ke *KernelIP6TEntry) SizeBytes() int { // MarshalBytes implements marshal.Marshallable.MarshalBytes. func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) { - ke.Entry.MarshalBytes(dst) + ke.Entry.MarshalUnsafe(dst) ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) { - ke.Entry.UnmarshalBytes(src) + ke.Entry.UnmarshalUnsafe(src) ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } @@ -264,6 +264,8 @@ const ( // NFNATRange corresponds to struct nf_nat_range in // include/uapi/linux/netfilter/nf_nat.h. 
+// +// +marshal type NFNATRange struct { Flags uint32 MinAddr Inet6Addr diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index bf73271c6..600820a0b 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -15,9 +15,8 @@ package linux import ( + "encoding/binary" "testing" - - "gvisor.dev/gvisor/pkg/binary" ) func TestSizes(t *testing.T) { @@ -42,7 +41,7 @@ func TestSizes(t *testing.T) { } for _, tc := range testCases { - if calculated := binary.Size(tc.typ); calculated != tc.defined { + if calculated := uintptr(binary.Size(tc.typ)); calculated != tc.defined { t.Errorf("%T has a defined size of %d and calculated size of %d", tc.typ, tc.defined, calculated) } } diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index b41f94a69..232fee67e 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -53,6 +53,8 @@ type SockAddrNetlink struct { const SockAddrNetlinkSize = 12 // NetlinkMessageHeader is struct nlmsghdr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkMessageHeader struct { Length uint32 Type uint16 @@ -99,6 +101,8 @@ const NLMSG_ALIGNTO = 4 // NetlinkAttrHeader is the header of a netlink attribute, followed by payload. // // This is struct nlattr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkAttrHeader struct { Length uint16 Type uint16 @@ -126,6 +130,8 @@ const ( ) // NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h. +// +// +marshal type NetlinkErrorMessage struct { Error int32 Header NetlinkMessageHeader diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index ceda0a8d3..581a11b24 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -85,6 +85,8 @@ const ( ) // InterfaceInfoMessage is struct ifinfomsg, from uapi/linux/rtnetlink.h. 
+// +// +marshal type InterfaceInfoMessage struct { Family uint8 _ uint8 @@ -164,6 +166,8 @@ const ( ) // InterfaceAddrMessage is struct ifaddrmsg, from uapi/linux/if_addr.h. +// +// +marshal type InterfaceAddrMessage struct { Family uint8 PrefixLen uint8 @@ -193,6 +197,8 @@ const ( ) // RouteMessage is struct rtmsg, from uapi/linux/rtnetlink.h. +// +// +marshal type RouteMessage struct { Family uint8 DstLen uint8 diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go index 50e22fe7e..e722971f1 100644 --- a/pkg/abi/linux/ptrace_amd64.go +++ b/pkg/abi/linux/ptrace_amd64.go @@ -61,3 +61,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 { func (p *PtraceRegs) StackPointer() uint64 { return p.Rsp } + +// SetStackPointer sets the stack pointer to the specified value. +func (p *PtraceRegs) SetStackPointer(sp uint64) { + p.Rsp = sp +} diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go index da36811d2..3d0906565 100644 --- a/pkg/abi/linux/ptrace_arm64.go +++ b/pkg/abi/linux/ptrace_arm64.go @@ -38,3 +38,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 { func (p *PtraceRegs) StackPointer() uint64 { return p.Sp } + +// SetStackPointer sets the stack pointer to the specified value. +func (p *PtraceRegs) SetStackPointer(sp uint64) { + p.Sp = sp +} diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 185eee0bb..95871b8a5 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -15,7 +15,6 @@ package linux import ( - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal" ) @@ -251,18 +250,24 @@ type SockAddrInet struct { } // Inet6MulticastRequest is struct ipv6_mreq, from uapi/linux/in6.h. +// +// +marshal type Inet6MulticastRequest struct { MulticastAddr Inet6Addr InterfaceIndex int32 } // InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h. 
+// +// +marshal type InetMulticastRequest struct { MulticastAddr InetAddr InterfaceAddr InetAddr } // InetMulticastRequestWithNIC is struct ip_mreqn, from uapi/linux/in.h. +// +// +marshal type InetMulticastRequestWithNIC struct { InetMulticastRequest InterfaceIndex int32 @@ -491,7 +496,7 @@ type TCPInfo struct { } // SizeOfTCPInfo is the binary size of a TCPInfo struct. -var SizeOfTCPInfo = int(binary.Size(TCPInfo{})) +var SizeOfTCPInfo = (*TCPInfo)(nil).SizeBytes() // Control message types, from linux/socket.h. const ( @@ -502,6 +507,8 @@ const ( // A ControlMessageHeader is the header for a socket control message. // // ControlMessageHeader represents struct cmsghdr from linux/socket.h. +// +// +marshal type ControlMessageHeader struct { Length uint64 Level int32 @@ -510,7 +517,7 @@ type ControlMessageHeader struct { // SizeOfControlMessageHeader is the binary size of a ControlMessageHeader // struct. -var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) +var SizeOfControlMessageHeader = (*ControlMessageHeader)(nil).SizeBytes() // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // @@ -527,6 +534,7 @@ type ControlMessageCredentials struct { // // ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. // +// +marshal // +stateify savable type ControlMessageIPPacketInfo struct { NIC int32 @@ -536,7 +544,7 @@ type ControlMessageIPPacketInfo struct { // SizeOfControlMessageCredentials is the binary size of a // ControlMessageCredentials struct. -var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{})) +var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes() // A ControlMessageRights is an SCM_RIGHTS socket control message. 
type ControlMessageRights []int32 diff --git a/pkg/bits/bits.go b/pkg/bits/bits.go index a26433ad6..d16448c3d 100644 --- a/pkg/bits/bits.go +++ b/pkg/bits/bits.go @@ -14,3 +14,13 @@ // Package bits includes all bit related types and operations. package bits + +// AlignUp rounds a length up to an alignment. align must be a power of 2. +func AlignUp(length int, align uint) int { + return (length + int(align) - 1) & ^(int(align) - 1) +} + +// AlignDown rounds a length down to an alignment. align must be a power of 2. +func AlignDown(length int, align uint) int { + return length & ^(int(align) - 1) +} diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index 2a6977f85..c17390522 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -26,6 +26,7 @@ go_test( library = ":bpf", deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/hostarch", + "//pkg/marshal", ], ) diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go index c85d786b9..f64a2dc50 100644 --- a/pkg/bpf/interpreter_test.go +++ b/pkg/bpf/interpreter_test.go @@ -15,10 +15,12 @@ package bpf import ( + "encoding/binary" "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" ) func TestCompilationErrors(t *testing.T) { @@ -750,29 +752,29 @@ func TestSimpleFilter(t *testing.T) { // desc is the test's description. desc string - // seccompData is the input data. - seccompData + // SeccompData is the input data. + data linux.SeccompData // expectedRet is the expected return value of the BPF program. 
expectedRet uint32 }{ { desc: "Invalid arch is rejected", - seccompData: seccompData{nr: 1 /* x86 exit */, arch: 0x40000003 /* AUDIT_ARCH_I386 */}, + data: linux.SeccompData{Nr: 1 /* x86 exit */, Arch: 0x40000003 /* AUDIT_ARCH_I386 */}, expectedRet: 0, }, { desc: "Disallowed syscall is rejected", - seccompData: seccompData{nr: 105 /* __NR_setuid */, arch: 0xc000003e}, + data: linux.SeccompData{Nr: 105 /* __NR_setuid */, Arch: 0xc000003e}, expectedRet: 0, }, { desc: "Allowed syscall is indeed allowed", - seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e}, + data: linux.SeccompData{Nr: 231 /* __NR_exit_group */, Arch: 0xc000003e}, expectedRet: 0x7fff0000, }, } { - ret, err := Exec(p, test.seccompData.asInput()) + ret, err := Exec(p, dataAsInput(&test.data)) if err != nil { t.Errorf("%s: expected return value of %d, got execution error: %v", test.desc, test.expectedRet, err) continue @@ -792,6 +794,6 @@ type seccompData struct { } // asInput converts a seccompData to a bpf.Input. 
-func (d *seccompData) asInput() Input { - return InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +func dataAsInput(data *linux.SeccompData) Input { + return InputBytes{marshal.Marshal(data), hostarch.ByteOrder} } diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index 1f75319a7..70018cf18 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -6,10 +6,7 @@ go_library( name = "compressio", srcs = ["compressio.go"], visibility = ["//:sandbox"], - deps = [ - "//pkg/binary", - "//pkg/sync", - ], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index b094c5662..615d7f134 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -48,12 +48,12 @@ import ( "compress/flate" "crypto/hmac" "crypto/sha256" + "encoding/binary" "errors" "hash" "io" "runtime" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sync" ) @@ -130,6 +130,10 @@ type worker struct { hashPool *hashPool input chan *chunk output chan result + + // scratch is a temporary buffer used for marshalling. This is declared + // unfront here to avoid reallocation. + scratch [4]byte } // work is the main work routine; see worker. @@ -167,7 +171,8 @@ func (w *worker) work(compress bool, level int) { // Write the hash, if enabled. if h != nil { - binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len())) + h.Write(w.scratch[:4]) c.h = h h = nil } @@ -175,7 +180,8 @@ func (w *worker) work(compress bool, level int) { // Check the hash of the compressed contents. if h != nil { h.Write(c.compressed.Bytes()) - binary.WriteUint32(h, binary.BigEndian, uint32(c.compressed.Len())) + binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len())) + h.Write(w.scratch[:4]) io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum))) sum := h.Sum(nil) @@ -352,6 +358,10 @@ type Reader struct { // in is the source. 
in io.Reader + + // scratch is a temporary buffer used for marshalling. This is declared + // unfront here to avoid reallocation. + scratch [4]byte } var _ io.Reader = (*Reader)(nil) @@ -368,14 +378,15 @@ func NewReader(in io.Reader, key []byte) (*Reader, error) { // Use double buffering for read. r.init(key, 2*runtime.GOMAXPROCS(0), false, 0) - var err error - if r.chunkSize, err = binary.ReadUint32(in, binary.BigEndian); err != nil { + if _, err := io.ReadFull(in, r.scratch[:4]); err != nil { return nil, err } + r.chunkSize = binary.BigEndian.Uint32(r.scratch[:4]) if r.hashPool != nil { h := r.hashPool.getHash() - binary.WriteUint32(h, binary.BigEndian, r.chunkSize) + binary.BigEndian.PutUint32(r.scratch[:], r.chunkSize) + h.Write(r.scratch[:4]) r.lastSum = h.Sum(nil) r.hashPool.putHash(h) sum := make([]byte, len(r.lastSum)) @@ -467,8 +478,7 @@ func (r *Reader) Read(p []byte) (int, error) { // reader. The length is used to limit the reader. // // See writer.flush. - l, err := binary.ReadUint32(r.in, binary.BigEndian) - if err != nil { + if _, err := io.ReadFull(r.in, r.scratch[:4]); err != nil { // This is generally okay as long as there // are still buffers outstanding. We actually // just wait for completion of those buffers here @@ -488,6 +498,7 @@ func (r *Reader) Read(p []byte) (int, error) { return done, err } } + l := binary.BigEndian.Uint32(r.scratch[:4]) // Read this chunk and schedule decompression. compressed := bufPool.Get().(*bytes.Buffer) @@ -573,6 +584,10 @@ type Writer struct { // closed indicates whether the file has been closed. closed bool + + // scratch is a temporary buffer used for marshalling. This is declared + // unfront here to avoid reallocation. 
+ scratch [4]byte } var _ io.Writer = (*Writer)(nil) @@ -594,13 +609,15 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, } w.init(key, 1+runtime.GOMAXPROCS(0), true, level) - if err := binary.WriteUint32(w.out, binary.BigEndian, chunkSize); err != nil { + binary.BigEndian.PutUint32(w.scratch[:], chunkSize) + if _, err := w.out.Write(w.scratch[:4]); err != nil { return nil, err } if w.hashPool != nil { h := w.hashPool.getHash() - binary.WriteUint32(h, binary.BigEndian, chunkSize) + binary.BigEndian.PutUint32(w.scratch[:], chunkSize) + h.Write(w.scratch[:4]) w.lastSum = h.Sum(nil) w.hashPool.putHash(h) if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil { @@ -616,7 +633,9 @@ func (w *Writer) flush(c *chunk) error { // Prefix each chunk with a length; this allows the reader to safely // limit reads while buffering. l := uint32(c.compressed.Len()) - if err := binary.WriteUint32(w.out, binary.BigEndian, l); err != nil { + + binary.BigEndian.PutUint32(w.scratch[:], l) + if _, err := w.out.Write(w.scratch[:4]); err != nil { return err } diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go index a6778a005..b33a20802 100644 --- a/pkg/coverage/coverage.go +++ b/pkg/coverage/coverage.go @@ -26,6 +26,7 @@ import ( "fmt" "io" "sort" + "sync/atomic" "testing" "gvisor.dev/gvisor/pkg/hostarch" @@ -34,12 +35,16 @@ import ( "github.com/bazelbuild/rules_go/go/tools/coverdata" ) -// coverageMu must be held while accessing coverdata.Cover. This prevents -// concurrent reads/writes from multiple threads collecting coverage data. -var coverageMu sync.RWMutex +var ( + // coverageMu must be held while accessing coverdata.Cover. This prevents + // concurrent reads/writes from multiple threads collecting coverage data. + coverageMu sync.RWMutex -// once ensures that globalData is only initialized once. -var once sync.Once + // reportOutput is the place to write out a coverage report. 
It should be + // closed after the report is written. It is protected by reportOutputMu. + reportOutput io.WriteCloser + reportOutputMu sync.Mutex +) // blockBitLength is the number of bits used to represent coverage block index // in a synthetic PC (the rest are used to represent the file index). Even @@ -51,12 +56,26 @@ var once sync.Once // file and every block. const blockBitLength = 16 -// KcovAvailable returns whether the kcov coverage interface is available. It is -// available as long as coverage is enabled for some files. -func KcovAvailable() bool { +// Available returns whether any coverage data is available. +func Available() bool { return len(coverdata.Cover.Blocks) > 0 } +// EnableReport sets up coverage reporting. +func EnableReport(w io.WriteCloser) { + reportOutputMu.Lock() + defer reportOutputMu.Unlock() + reportOutput = w +} + +// KcovSupported returns whether the kcov interface should be made available. +// +// If coverage reporting is on, do not turn on kcov, which will consume +// coverage data. +func KcovSupported() bool { + return (reportOutput == nil) && Available() +} + var globalData struct { // files is the set of covered files sorted by filename. It is calculated at // startup. @@ -65,6 +84,9 @@ var globalData struct { // syntheticPCs are a set of PCs calculated at startup, where the PC // at syntheticPCs[i][j] corresponds to file i, block j. syntheticPCs [][]uint64 + + // once ensures that globalData is only initialized once. + once sync.Once } // ClearCoverageData clears existing coverage data. @@ -166,7 +188,7 @@ func ConsumeCoverageData(w io.Writer) int { // InitCoverageData initializes globalData. It should be called before any kcov // data is written. func InitCoverageData() { - once.Do(func() { + globalData.once.Do(func() { // First, order all files. Then calculate synthetic PCs for every block // (using the well-defined ordering for files as well). 
for file := range coverdata.Cover.Blocks { @@ -185,6 +207,38 @@ func InitCoverageData() { }) } +// reportOnce ensures that a coverage report is written at most once. For a +// complete coverage report, Report should be called during the sandbox teardown +// process. Report is called from multiple places (which may overlap) so that a +// coverage report is written in different sandbox exit scenarios. +var reportOnce sync.Once + +// Report writes out a coverage report with all blocks that have been covered. +// +// TODO(b/144576401): Decide whether this should actually be in LCOV format +func Report() error { + if reportOutput == nil { + return nil + } + + var err error + reportOnce.Do(func() { + for file, counters := range coverdata.Cover.Counters { + blocks := coverdata.Cover.Blocks[file] + for i := 0; i < len(counters); i++ { + if atomic.LoadUint32(&counters[i]) > 0 { + err = writeBlock(reportOutput, file, blocks[i]) + if err != nil { + return + } + } + } + } + reportOutput.Close() + }) + return err +} + // Symbolize prints information about the block corresponding to pc. func Symbolize(out io.Writer, pc uint64) error { fileNum, blockNum := syntheticPCToIndexes(pc) @@ -196,18 +250,32 @@ func Symbolize(out io.Writer, pc uint64) error { if err != nil { return err } - writeBlock(out, pc, file, block) - return nil + return writeBlockWithPC(out, pc, file, block) } // WriteAllBlocks prints all information about all blocks along with their // corresponding synthetic PCs. 
-func WriteAllBlocks(out io.Writer) { +func WriteAllBlocks(out io.Writer) error { for fileNum, file := range globalData.files { for blockNum, block := range coverdata.Cover.Blocks[file] { - writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block) + if err := writeBlockWithPC(out, calculateSyntheticPC(fileNum, blockNum), file, block); err != nil { + return err + } } } + return nil +} + +func writeBlockWithPC(out io.Writer, pc uint64, file string, block testing.CoverBlock) error { + if _, err := io.WriteString(out, fmt.Sprintf("%#x\n", pc)); err != nil { + return err + } + return writeBlock(out, file, block) +} + +func writeBlock(out io.Writer, file string, block testing.CoverBlock) error { + _, err := io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1)) + return err } func calculateSyntheticPC(fileNum int, blockNum int) uint64 { @@ -239,8 +307,3 @@ func blockFromIndex(file string, i int) (testing.CoverBlock, error) { } return blocks[i], nil } - -func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) { - io.WriteString(out, fmt.Sprintf("%#x\n", pc)) - io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1)) -} diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD index 35683fe98..b4e05f922 100644 --- a/pkg/gohacks/BUILD +++ b/pkg/gohacks/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,3 +10,11 @@ go_library( stateify = False, visibility = ["//:sandbox"], ) + +go_test( + name = "gohacks_test", + size = "small", + srcs = ["gohacks_test.go"], + library = ":gohacks", + deps = ["@org_golang_x_sys//unix:go_default_library"], +) diff --git a/pkg/gohacks/gohacks_test.go b/pkg/gohacks/gohacks_test.go new file mode 100644 index 000000000..e18c8abc7 --- /dev/null +++ b/pkg/gohacks/gohacks_test.go @@ -0,0 +1,97 @@ +// Copyright 2021 The 
gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gohacks + +import ( + "io/ioutil" + "math/rand" + "os" + "runtime/debug" + "testing" + + "golang.org/x/sys/unix" +) + +func randBuf(size int) []byte { + b := make([]byte, size) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +// Size of a page in bytes. Cloned from hostarch.PageSize to avoid a circular +// dependency. +const pageSize = 4096 + +func testCopy(dst, src []byte) (panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + debug.SetPanicOnFault(true) + copy(dst, src) + return panicked +} + +func TestSegVOnMemmove(t *testing.T) { + // Test that SIGSEGVs received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. + const bufLen = pageSize + a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer unix.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} + +func TestSigbusOnMemmove(t *testing.T) { + // Test that SIGBUS received by runtime.memmove when *not* doing + // CopyIn or CopyOut work gets propagated to the runtime. 
+ const bufLen = pageSize + f, err := ioutil.TempFile("", "sigbus_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + os.Remove(f.Name()) + defer f.Close() + + a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + if err != nil { + t.Fatalf("Mmap failed: %v", err) + + } + defer unix.Munmap(a) + b := randBuf(bufLen) + + if !testCopy(b, a) { + t.Fatalf("testCopy didn't panic when it should have") + } + + if !testCopy(a, b) { + t.Fatalf("testCopy didn't panic when it should have") + } +} diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go index 10bbb1f58..374aac2b4 100644 --- a/pkg/gohacks/gohacks_unsafe.go +++ b/pkg/gohacks/gohacks_unsafe.go @@ -75,3 +75,17 @@ func StringFromImmutableBytes(bs []byte) string { // strings.Builder.String(). return *(*string)(unsafe.Pointer(&bs)) } + +// Note that go:linkname silently doesn't work if the local name is exported, +// necessitating an indirection for exported functions. + +// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>. +// +//go:nosplit +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD index 7cd89e639..7a5002176 100644 --- a/pkg/marshal/BUILD +++ b/pkg/marshal/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "marshal.go", "marshal_impl_util.go", + "util.go", ], visibility = [ "//:sandbox", diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go index eb036feae..7da450ce8 100644 --- a/pkg/marshal/marshal.go +++ b/pkg/marshal/marshal.go @@ -166,6 +166,9 @@ type Marshallable interface { // %s is the first argument to the slice clause. This directive is not supported // for newtypes on arrays. // +// Note: Partial copies are not supported for Slice API UnmarshalUnsafe and +// MarshalUnsafe. 
+// // The slice clause also takes an optional second argument, which must be the // value "inner": // diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index 32c8ed138..6f38992b7 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -125,6 +125,81 @@ func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) { var _ marshal.Marshallable = (*ByteSlice)(nil) +// The following set of functions are convenient shorthands for wrapping a +// built-in type in a marshallable primitive type. For example: +// +// func useMarshallable(m marshal.Marshallable) { ... } +// +// // Compare: +// +// buf = []byte{...} +// // useMarshallable(&primitive.ByteSlice(buf)) // Not allowed, can't address temp value. +// bufP := primitive.ByteSlice(buf) +// useMarshallable(&bufP) +// +// // Vs: +// +// useMarshallable(AsByteSlice(buf)) +// +// Note that the argument to these function escapes, so avoid using them on very +// hot code paths. But generally if a function accepts an interface as an +// argument, the argument escapes anyways. + +// AllocateInt8 returns x as a marshallable. +func AllocateInt8(x int8) marshal.Marshallable { + p := Int8(x) + return &p +} + +// AllocateUint8 returns x as a marshallable. +func AllocateUint8(x uint8) marshal.Marshallable { + p := Uint8(x) + return &p +} + +// AllocateInt16 returns x as a marshallable. +func AllocateInt16(x int16) marshal.Marshallable { + p := Int16(x) + return &p +} + +// AllocateUint16 returns x as a marshallable. +func AllocateUint16(x uint16) marshal.Marshallable { + p := Uint16(x) + return &p +} + +// AllocateInt32 returns x as a marshallable. +func AllocateInt32(x int32) marshal.Marshallable { + p := Int32(x) + return &p +} + +// AllocateUint32 returns x as a marshallable. +func AllocateUint32(x uint32) marshal.Marshallable { + p := Uint32(x) + return &p +} + +// AllocateInt64 returns x as a marshallable. 
+func AllocateInt64(x int64) marshal.Marshallable { + p := Int64(x) + return &p +} + +// AllocateUint64 returns x as a marshallable. +func AllocateUint64(x uint64) marshal.Marshallable { + p := Uint64(x) + return &p +} + +// AsByteSlice returns b as a marshallable. Note that this allocates a new slice +// header, but does not copy the slice contents. +func AsByteSlice(b []byte) marshal.Marshallable { + bs := ByteSlice(b) + return &bs +} + // Below, we define some convenience functions for marshalling primitive types // using the newtypes above, without requiring superfluous casts. diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/marshal/util.go index c9dc7e773..c1e5475bd 100644 --- a/pkg/tcpip/transport/tcp/rack_state.go +++ b/pkg/marshal/util.go @@ -12,18 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tcp +package marshal -import ( - "time" -) - -// saveXmitTime is invoked by stateify. -func (rc *rackControl) saveXmitTime() unixTime { - return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (rc *rackControl) loadXmitTime(unix unixTime) { - rc.xmitTime = time.Unix(unix.second, unix.nano) +// Marshal returns the serialized contents of m in a newly allocated +// byte slice. +func Marshal(m Marshallable) []byte { + buf := make([]byte, m.SizeBytes()) + m.MarshalUnsafe(buf) + return buf } diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 961bd4dcf..ac7868ad9 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -36,7 +36,6 @@ const ( ) // DigestSize returns the size (in bytes) of a digest. -// TODO(b/156980949): Allow config SHA384. 
func DigestSize(hashAlgorithm int) int { switch hashAlgorithm { case linux.FS_VERITY_HASH_ALG_SHA256: @@ -69,7 +68,6 @@ func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) blockSize: hostarch.PageSize, } - // TODO(b/156980949): Allow config SHA384. switch hashAlgorithms { case linux.FS_VERITY_HASH_ALG_SHA256: layout.digestSize = sha256DigestSize @@ -238,6 +236,7 @@ func Generate(params *GenerateParams) ([]byte, error) { Mode: params.Mode, UID: params.UID, GID: params.GID, + Children: params.Children, SymlinkTarget: params.SymlinkTarget, } @@ -428,8 +427,6 @@ func Verify(params *VerifyParams) (int64, error) { } // If this is the end of file, zero the remaining bytes in buf, // otherwise they are still from the previous block. - // TODO(b/162908070): Investigate possible issues with zero - // padding the data. if bytesRead < len(buf) { for j := bytesRead; j < len(buf); j++ { buf[j] = 0 diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index c9f9357de..e822fe77d 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -35,10 +35,15 @@ var ( // ErrInitializationDone indicates that the caller tried to create a // new metric after initialization. ErrInitializationDone = errors.New("metric cannot be created after initialization is complete") + + // WeirdnessMetric is a metric with fields created to track the number + // of weird occurrences such as time fallback, partial_result and + // vsyscall count. + WeirdnessMetric *Uint64Metric ) // Uint64Metric encapsulates a uint64 that represents some kind of metric to be -// monitored. +// monitored. We currently support metrics with at most one field. // // Metrics are not saved across save/restore and thus reset to zero on restore. // @@ -46,6 +51,16 @@ var ( type Uint64Metric struct { // value is the actual value of the metric. It must be accessed atomically. value uint64 + + // numFields is the number of metric fields. It is immutable once + // initialized. 
+ numFields int + + // mu protects the below fields. + mu sync.RWMutex `state:"nosave"` + + // fields is the map of fields in the metric. + fields map[string]uint64 } var ( @@ -97,8 +112,19 @@ type customUint64Metric struct { // metadata describes the metric. It is immutable. metadata *pb.MetricMetadata - // value returns the current value of the metric. - value func() uint64 + // value returns the current value of the metric for the given set of + // fields. It takes a variadic number of field values as argument. + value func(fieldValues ...string) uint64 +} + +// Field contains the field name and allowed values for the metric which is +// used in registration of the metric. +type Field struct { + // name is the metric field name. + name string + + // allowedValues is the list of allowed values for the field. + allowedValues []string } // RegisterCustomUint64Metric registers a metric with the given name. @@ -109,7 +135,8 @@ type customUint64Metric struct { // Preconditions: // * name must be globally unique. // * Initialize/Disable have not been called. -func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error { +// * value is expected to accept exactly len(fields) arguments. +func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func(...string) uint64, fields ...Field) error { if initialized { return ErrInitializationDone } @@ -129,13 +156,25 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met }, value: value, } + + // Metrics can exist without fields. 
+ if len(fields) > 1 { + panic("Sentry metrics support at most one field") + } + + for _, field := range fields { + allMetrics.m[name].metadata.Fields = append(allMetrics.m[name].metadata.Fields, &pb.MetricMetadata_Field{ + FieldName: field.name, + AllowedValues: field.allowedValues, + }) + } return nil } -// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric and panics -// if it returns an error. -func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func() uint64) { - if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value); err != nil { +// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric for metrics +// without fields and panics if it returns an error. +func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) { + if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil { panic(fmt.Sprintf("Unable to register metric %q: %v", name, err)) } } @@ -144,15 +183,24 @@ func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, descript // name. // // Metrics must be statically defined (i.e., at init). 
-func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string) (*Uint64Metric, error) { - var m Uint64Metric - return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value) +func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string, fields ...Field) (*Uint64Metric, error) { + m := Uint64Metric{ + numFields: len(fields), + } + + if m.numFields == 1 { + m.fields = make(map[string]uint64) + for _, fieldValue := range fields[0].allowedValues { + m.fields[fieldValue] = 0 + } + } + return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value, fields...) } // MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an // error. -func MustCreateNewUint64Metric(name string, sync bool, description string) *Uint64Metric { - m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description) +func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric { + m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...) if err != nil { panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) } @@ -169,19 +217,56 @@ func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description st return m } -// Value returns the current value of the metric. -func (m *Uint64Metric) Value() uint64 { - return atomic.LoadUint64(&m.value) +// Value returns the current value of the metric for the given set of fields. 
+func (m *Uint64Metric) Value(fieldValues ...string) uint64 { + if m.numFields != len(fieldValues) { + panic(fmt.Sprintf("Number of fieldValues %d is not equal to the number of metric fields %d", len(fieldValues), m.numFields)) + } + + switch m.numFields { + case 0: + return atomic.LoadUint64(&m.value) + case 1: + m.mu.RLock() + defer m.mu.RUnlock() + + fieldValue := fieldValues[0] + if _, ok := m.fields[fieldValue]; !ok { + panic(fmt.Sprintf("Metric does not allow to have field value %s", fieldValue)) + } + return m.fields[fieldValue] + default: + panic("Sentry metrics do not support more than one field") + } } -// Increment increments the metric by 1. -func (m *Uint64Metric) Increment() { - atomic.AddUint64(&m.value, 1) +// Increment increments the metric field by 1. +func (m *Uint64Metric) Increment(fieldValues ...string) { + m.IncrementBy(1, fieldValues...) } // IncrementBy increments the metric by v. -func (m *Uint64Metric) IncrementBy(v uint64) { - atomic.AddUint64(&m.value, v) +func (m *Uint64Metric) IncrementBy(v uint64, fieldValues ...string) { + if m.numFields != len(fieldValues) { + panic(fmt.Sprintf("Number of fieldValues %d is not equal to the number of metric fields %d", len(fieldValues), m.numFields)) + } + + switch m.numFields { + case 0: + atomic.AddUint64(&m.value, v) + return + case 1: + fieldValue := fieldValues[0] + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.fields[fieldValue]; !ok { + panic(fmt.Sprintf("Metric does not allow to have field value %s", fieldValue)) + } + m.fields[fieldValue] += v + default: + panic("Sentry metrics do not support more than one field") + } } // metricSet holds named metrics. @@ -199,14 +284,30 @@ func makeMetricSet() metricSet { // Values returns a snapshot of all values in m. 
func (m *metricSet) Values() metricValues { vals := make(metricValues) + for k, v := range m.m { - vals[k] = v.value() + fields := v.metadata.GetFields() + switch len(fields) { + case 0: + vals[k] = v.value() + case 1: + values := fields[0].GetAllowedValues() + fieldsMap := make(map[string]uint64) + for _, fieldValue := range values { + fieldsMap[fieldValue] = v.value(fieldValue) + } + vals[k] = fieldsMap + default: + panic(fmt.Sprintf("Unsupported number of metric fields: %d", len(fields))) + } } return vals } -// metricValues contains a copy of the values of all metrics. -type metricValues map[string]uint64 +// metricValues contains a copy of the values of all metrics. It is a map +// with key as metric name and value can be either uint64 or map[string]uint64 +// to support metrics with one field. +type metricValues map[string]interface{} var ( // emitMu protects metricsAtLastEmit and ensures that all emitted @@ -233,14 +334,37 @@ func EmitMetricUpdate() { snapshot := allMetrics.Values() m := pb.MetricUpdate{} + // On the first call metricsAtLastEmit will be empty. Include all + // metrics then. for k, v := range snapshot { - // On the first call metricsAtLastEmit will be empty. Include - // all metrics then. - if prev, ok := metricsAtLastEmit[k]; !ok || prev != v { + prev, ok := metricsAtLastEmit[k] + switch t := v.(type) { + case uint64: + // Metric exists and value did not change. + if ok && prev.(uint64) == t { + continue + } + m.Metrics = append(m.Metrics, &pb.MetricValue{ Name: k, - Value: &pb.MetricValue_Uint64Value{v}, + Value: &pb.MetricValue_Uint64Value{t}, }) + case map[string]uint64: + for fieldValue, metricValue := range t { + // Emit data on the first call only if the field + // value has been incremented. For all other + // calls, emit data if the field value has been + // changed from the previous emit. 
+ if (!ok && metricValue == 0) || (ok && prev.(map[string]uint64)[fieldValue] == metricValue) { + continue + } + + m.Metrics = append(m.Metrics, &pb.MetricValue{ + Name: k, + FieldValues: []string{fieldValue}, + Value: &pb.MetricValue_Uint64Value{metricValue}, + }) + } } } @@ -261,3 +385,16 @@ func EmitMetricUpdate() { eventchannel.Emit(&m) } + +// CreateSentryMetrics creates the sentry metrics during kernel initialization. +func CreateSentryMetrics() { + if WeirdnessMetric != nil { + return + } + + WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result and vsyscalls invoked in the sandbox", + Field{ + name: "weirdness_type", + allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count"}, + }) +} diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto index 3cc89047d..53c8b4b50 100644 --- a/pkg/metric/metric.proto +++ b/pkg/metric/metric.proto @@ -48,6 +48,15 @@ message MetricMetadata { // units is the units of the metric value. Units units = 6; + + message Field { + string field_name = 1; + repeated string allowed_values = 2; + } + + // fields contains the metric fields. Currently a metric can have at most + // one field. + repeated Field fields = 7; } // MetricRegistration contains the metadata for all metrics that will be in @@ -66,6 +75,8 @@ message MetricValue { oneof value { uint64 uint64_value = 2; } + + repeated string field_values = 4; } // MetricUpdate contains new values for multiple distinct metrics. diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go index aefd0ea5c..c71dfd460 100644 --- a/pkg/metric/metric_test.go +++ b/pkg/metric/metric_test.go @@ -59,8 +59,9 @@ func reset() { } const ( - fooDescription = "Foo!" - barDescription = "Bar Baz" + fooDescription = "Foo!" 
+ barDescription = "Bar Baz" + counterDescription = "Counter" ) func TestInitialize(t *testing.T) { @@ -95,7 +96,7 @@ func TestInitialize(t *testing.T) { foundBar := false for _, m := range mr.Metrics { if m.Type != pb.MetricMetadata_TYPE_UINT64 { - t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_TYPE_UINT64) + t.Errorf("Metadata %+v Type got %v want pb.MetricMetadata_TYPE_UINT64", m, m.Type) } if !m.Cumulative { t.Errorf("Metadata %+v Cumulative got false want true", m) @@ -256,3 +257,88 @@ func TestEmitMetricUpdate(t *testing.T) { t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value) } } + +func TestEmitMetricUpdateWithFields(t *testing.T) { + defer reset() + + field := Field{ + name: "weirdness_type", + allowedValues: []string{"weird1", "weird2"}} + + counter, err := NewUint64Metric("/weirdness", false, pb.MetricMetadata_UNITS_NONE, counterDescription, field) + if err != nil { + t.Fatalf("NewUint64Metric got err %v want nil", err) + } + + Initialize() + + // Don't care about the registration metrics. + emitter.Reset() + EmitMetricUpdate() + + // For metrics with fields, we do not emit data unless the value is + // incremented. 
+ if len(emitter) != 0 { + t.Fatalf("EmitMetricUpdate emitted %d events want 0", len(emitter)) + } + + counter.IncrementBy(4, "weird1") + counter.Increment("weird2") + + emitter.Reset() + EmitMetricUpdate() + + if len(emitter) != 1 { + t.Fatalf("EmitMetricUpdate emitted %d events want 1", len(emitter)) + } + + update, ok := emitter[0].(*pb.MetricUpdate) + if !ok { + t.Fatalf("emitter %v got %T want pb.MetricUpdate", emitter[0], emitter[0]) + } + + if len(update.Metrics) != 2 { + t.Errorf("MetricUpdate got %d metrics want 2", len(update.Metrics)) + } + + foundWeird1 := false + foundWeird2 := false + for i := 0; i < len(update.Metrics); i++ { + m := update.Metrics[i] + + if m.Name != "/weirdness" { + t.Errorf("Metric %+v name got %q want '/weirdness'", m, m.Name) + } + if len(m.FieldValues) != 1 { + t.Errorf("MetricUpdate got %d fields want 1", len(m.FieldValues)) + } + + switch m.FieldValues[0] { + case "weird1": + uv, ok := m.Value.(*pb.MetricValue_Uint64Value) + if !ok { + t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) + } + if uv.Uint64Value != 4 { + t.Errorf("%v: Value got %v want 4", m, uv.Uint64Value) + } + foundWeird1 = true + case "weird2": + uv, ok := m.Value.(*pb.MetricValue_Uint64Value) + if !ok { + t.Errorf("%+v: value %v got %T want pb.MetricValue_Uint64Value", m, m.Value, m.Value) + } + if uv.Uint64Value != 1 { + t.Errorf("%v: Value got %v want 1", m, uv.Uint64Value) + } + foundWeird2 = true + } + } + + if !foundWeird1 { + t.Errorf("Field value weird1 not found: %+v", emitter) + } + if !foundWeird2 { + t.Errorf("Field value weird2 not found: %+v", emitter) + } +} diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 7abc82e1b..28396b0ea 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -121,6 +121,22 @@ func (c *clientFile) WalkGetAttr(components []string) ([]QID, File, AttrMask, At return rwalkgetattr.QIDs, c.client.newFile(FID(fid)), rwalkgetattr.Valid, rwalkgetattr.Attr, nil } +func 
(c *clientFile) MultiGetAttr(names []string) ([]FullStat, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, unix.EBADF + } + + if !versionSupportsTmultiGetAttr(c.client.version) { + return DefaultMultiGetAttr(c, names) + } + + rmultigetattr := Rmultigetattr{} + if err := c.client.sendRecv(&Tmultigetattr{FID: c.fid, Names: names}, &rmultigetattr); err != nil { + return nil, err + } + return rmultigetattr.Stats, nil +} + // StatFS implements File.StatFS. func (c *clientFile) StatFS() (FSStat, error) { if atomic.LoadUint32(&c.closed) != 0 { diff --git a/pkg/p9/file.go b/pkg/p9/file.go index c59c6a65b..97e0231d6 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -15,6 +15,8 @@ package p9 import ( + "errors" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/fd" ) @@ -72,6 +74,15 @@ type File interface { // On the server, WalkGetAttr has a read concurrency guarantee. WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) + // MultiGetAttr batches up multiple calls to GetAttr(). names is a list of + // path components similar to Walk(). If the first component name is empty, + // the current file is stat'd and included in the results. If the walk reaches + // a file that doesn't exist or is not a directory, MultiGetAttr returns the + // partial result with no error. + // + // On the server, MultiGetAttr has a read concurrency guarantee. + MultiGetAttr(names []string) ([]FullStat, error) + // StatFS returns information about the file system associated with // this file. // @@ -306,6 +317,53 @@ func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { type DisallowServerCalls struct{} // Renamed implements File.Renamed. -func (*clientFile) Renamed(File, string) { +func (*DisallowServerCalls) Renamed(File, string) { panic("Renamed should not be called on the client") } + +// DefaultMultiGetAttr implements File.MultiGetAttr() on top of File.
+func DefaultMultiGetAttr(start File, names []string) ([]FullStat, error) { + stats := make([]FullStat, 0, len(names)) + parent := start + mask := AttrMaskAll() + for i, name := range names { + if len(name) == 0 && i == 0 { + qid, valid, attr, err := parent.GetAttr(mask) + if err != nil { + return nil, err + } + stats = append(stats, FullStat{ + QID: qid, + Valid: valid, + Attr: attr, + }) + continue + } + qids, child, valid, attr, err := parent.WalkGetAttr([]string{name}) + if parent != start { + _ = parent.Close() + } + if err != nil { + if errors.Is(err, unix.ENOENT) { + return stats, nil + } + return nil, err + } + stats = append(stats, FullStat{ + QID: qids[0], + Valid: valid, + Attr: attr, + }) + if attr.Mode.FileType() != ModeDirectory { + // Doesn't need to continue if entry is not a dir. Including symlinks + // that cannot be followed. + _ = child.Close() + break + } + parent = child + } + if parent != start { + _ = parent.Close() + } + return stats, nil +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 58312d0cc..758e11b13 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -1421,3 +1421,31 @@ func (t *Tchannel) handle(cs *connState) message { } return rchannel } + +// handle implements handler.handle. +func (t *Tmultigetattr) handle(cs *connState) message { + for i, name := range t.Names { + if len(name) == 0 && i == 0 { + // Empty name is allowed on the first entry to indicate that the current + // FID needs to be included in the result. 
+ continue + } + if err := checkSafeName(name); err != nil { + return newErr(err) + } + } + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(unix.EBADF) + } + defer ref.DecRef() + + var stats []FullStat + if err := ref.safelyRead(func() (err error) { + stats, err = ref.file.MultiGetAttr(t.Names) + return err + }); err != nil { + return newErr(err) + } + return &Rmultigetattr{Stats: stats} +} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index cf13cbb69..2ff4694c0 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -254,8 +254,8 @@ func (r *Rwalk) decode(b *buffer) { // encode implements encoder.encode. func (r *Rwalk) encode(b *buffer) { b.Write16(uint16(len(r.QIDs))) - for _, q := range r.QIDs { - q.encode(b) + for i := range r.QIDs { + r.QIDs[i].encode(b) } } @@ -2243,8 +2243,8 @@ func (r *Rwalkgetattr) encode(b *buffer) { r.Valid.encode(b) r.Attr.encode(b) b.Write16(uint16(len(r.QIDs))) - for _, q := range r.QIDs { - q.encode(b) + for i := range r.QIDs { + r.QIDs[i].encode(b) } } @@ -2552,6 +2552,80 @@ func (r *Rchannel) String() string { return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length) } +// Tmultigetattr is a multi-getattr request. +type Tmultigetattr struct { + // FID is the FID to be walked. + FID FID + + // Names are the set of names to be walked. + Names []string +} + +// decode implements encoder.decode. +func (t *Tmultigetattr) decode(b *buffer) { + t.FID = b.ReadFID() + n := b.Read16() + t.Names = t.Names[:0] + for i := 0; i < int(n); i++ { + t.Names = append(t.Names, b.ReadString()) + } +} + +// encode implements encoder.encode. +func (t *Tmultigetattr) encode(b *buffer) { + b.WriteFID(t.FID) + b.Write16(uint16(len(t.Names))) + for _, name := range t.Names { + b.WriteString(name) + } +} + +// Type implements message.Type. +func (*Tmultigetattr) Type() MsgType { + return MsgTmultigetattr +} + +// String implements fmt.Stringer. 
+func (t *Tmultigetattr) String() string { + return fmt.Sprintf("Tmultigetattr{FID: %d, Names: %v}", t.FID, t.Names) +} + +// Rmultigetattr is a multi-getattr response. +type Rmultigetattr struct { + // Stats are the set of FullStat returned for each of the names in the + // request. + Stats []FullStat +} + +// decode implements encoder.decode. +func (r *Rmultigetattr) decode(b *buffer) { + n := b.Read16() + r.Stats = r.Stats[:0] + for i := 0; i < int(n); i++ { + var fs FullStat + fs.decode(b) + r.Stats = append(r.Stats, fs) + } +} + +// encode implements encoder.encode. +func (r *Rmultigetattr) encode(b *buffer) { + b.Write16(uint16(len(r.Stats))) + for i := range r.Stats { + r.Stats[i].encode(b) + } +} + +// Type implements message.Type. +func (*Rmultigetattr) Type() MsgType { + return MsgRmultigetattr +} + +// String implements fmt.Stringer. +func (r *Rmultigetattr) String() string { + return fmt.Sprintf("Rmultigetattr{Stats: %v}", r.Stats) +} + const maxCacheSize = 3 // msgFactory is used to reduce allocations by caching messages for reuse. 
@@ -2717,6 +2791,8 @@ func init() { msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} }) msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} }) + msgRegistry.register(MsgTmultigetattr, func() message { return &Tmultigetattr{} }) + msgRegistry.register(MsgRmultigetattr, func() message { return &Rmultigetattr{} }) msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 648cf4b49..3d452a0bd 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -402,6 +402,8 @@ const ( MsgRallocate MsgType = 139 MsgTsetattrclunk MsgType = 140 MsgRsetattrclunk MsgType = 141 + MsgTmultigetattr MsgType = 142 + MsgRmultigetattr MsgType = 143 MsgTchannel MsgType = 250 MsgRchannel MsgType = 251 ) @@ -1178,3 +1180,29 @@ func (a *AllocateMode) encode(b *buffer) { } b.Write32(mask) } + +// FullStat is used in the result of a MultiGetAttr call. +type FullStat struct { + QID QID + Valid AttrMask + Attr Attr +} + +// String implements fmt.Stringer. +func (f *FullStat) String() string { + return fmt.Sprintf("FullStat{QID: %v, Valid: %v, Attr: %v}", f.QID, f.Valid, f.Attr) +} + +// decode implements encoder.decode. +func (f *FullStat) decode(b *buffer) { + f.QID.decode(b) + f.Valid.decode(b) + f.Attr.decode(b) +} + +// encode implements encoder.encode. +func (f *FullStat) encode(b *buffer) { + f.QID.encode(b) + f.Valid.encode(b) + f.Attr.encode(b) +} diff --git a/pkg/p9/version.go b/pkg/p9/version.go index 8d7168ef5..950236162 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. 
- highestSupportedVersion uint32 = 12 + highestSupportedVersion uint32 = 13 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -179,3 +179,9 @@ func versionSupportsListRemoveXattr(v uint32) bool { func versionSupportsTsetattrclunk(v uint32) bool { return v >= 12 } + +// versionSupportsTmultiGetAttr returns true if version v supports +// the TmultiGetAttr message. +func versionSupportsTmultiGetAttr(v uint32) bool { + return v >= 13 +} diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 6992e1de8..4aecb8007 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -30,6 +30,9 @@ import ( // RefCounter is the interface to be implemented by objects that are reference // counted. +// +// TODO(gvisor.dev/issue/1624): Get rid of most of this package and replace it +// with refsvfs2. type RefCounter interface { // IncRef increments the reference counter on the object. IncRef() @@ -181,6 +184,9 @@ func (w *WeakRef) zap() { // AtomicRefCount keeps a reference count using atomic operations and calls the // destructor when the count reaches zero. // +// Do not use AtomicRefCount for new ref-counted objects! It is deprecated in +// favor of the refsvfs2 package. +// // N.B. To allow the zero-object to be initialized, the count is offset by // 1, that is, when refCount is n, there are really n+1 references. // @@ -215,8 +221,8 @@ type AtomicRefCount struct { // LeakMode configures the leak checker. type LeakMode uint32 -// TODO(gvisor.dev/issue/1624): Simplify down to two modes once vfs1 ref -// counting is gone. +// TODO(gvisor.dev/issue/1624): Simplify down to two modes (on/off) once vfs1 +// ref counting is gone. const ( // UninitializedLeakChecking indicates that the leak checker has not yet been initialized. 
UninitializedLeakChecking LeakMode = iota diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD index 0377c0876..7c1a8c792 100644 --- a/pkg/refsvfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -1,3 +1,5 @@ +# TODO(gvisor.dev/issue/1624): rename this directory/package to "refs" once VFS1 +# is gone and the current refs package can be deleted. load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template") diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go index 0472eca3f..fb8984dd6 100644 --- a/pkg/refsvfs2/refs_map.go +++ b/pkg/refsvfs2/refs_map.go @@ -112,20 +112,27 @@ func logEvent(obj CheckedObject, msg string) { log.Infof("[%s %p] %s:\n%s", obj.RefType(), obj, msg, refs_vfs1.FormatStack(refs_vfs1.RecordStack())) } +// checkOnce makes sure that leak checking is only done once. DoLeakCheck is +// called from multiple places (which may overlap) to cover different sandbox +// exit scenarios. +var checkOnce sync.Once + // DoLeakCheck iterates through the live object map and logs a message for each // object. It is called once no reference-counted objects should be reachable // anymore, at which point anything left in the map is considered a leak. 
func DoLeakCheck() { if leakCheckEnabled() { - liveObjectsMu.Lock() - defer liveObjectsMu.Unlock() - leaked := len(liveObjects) - if leaked > 0 { - msg := fmt.Sprintf("Leak checking detected %d leaked objects:\n", leaked) - for obj := range liveObjects { - msg += obj.LeakMessage() + "\n" + checkOnce.Do(func() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + leaked := len(liveObjects) + if leaked > 0 { + msg := fmt.Sprintf("Leak checking detected %d leaked objects:\n", leaked) + for obj := range liveObjects { + msg += obj.LeakMessage() + "\n" + } + log.Warningf(msg) } - log.Warningf(msg) - } + }) } } diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index 92d2330cb..41dfd0bf9 100644 --- a/pkg/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go @@ -250,6 +250,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { } SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point. WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. + RestoreKernelFPState() // escapes: no. Restore kernel MXCSR. return } @@ -321,3 +322,21 @@ func SetCPUIDFaulting(on bool) bool { func ReadCR2() uintptr { return readCR2() } + +// kernelMXCSR is the value of the mxcsr register in the Sentry. +// +// The MXCSR control configuration is initialized once and never changed. Look +// at src/cmd/compile/abi-internal.md in the golang sources for more details. +var kernelMXCSR uint32 + +// RestoreKernelFPState restores the Sentry floating point state. +// +//go:nosplit +func RestoreKernelFPState() { + // Restore the MXCSR control configuration. 
+ ldmxcsr(&kernelMXCSR) +} + +func init() { + stmxcsr(&kernelMXCSR) +} diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go index 7975e5f92..21db910a2 100644 --- a/pkg/ring0/kernel_arm64.go +++ b/pkg/ring0/kernel_arm64.go @@ -65,7 +65,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer()) if switchOpts.Flush { - FlushTlbByASID(uintptr(switchOpts.UserASID)) + LocalFlushTlbByASID(uintptr(switchOpts.UserASID)) } regs := switchOpts.Registers @@ -89,3 +89,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { return } + +// RestoreKernelFPState restores the Sentry floating point state. +// +//go:nosplit +func RestoreKernelFPState() { +} diff --git a/pkg/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go index 0ec5c3bc5..3e6bb9663 100644 --- a/pkg/ring0/lib_amd64.go +++ b/pkg/ring0/lib_amd64.go @@ -61,6 +61,12 @@ func wrgsbase(addr uintptr) // wrgsmsr writes to the GS_BASE MSR. func wrgsmsr(addr uintptr) +// stmxcsr reads the MXCSR control and status register. +func stmxcsr(addr *uint32) + +// ldmxcsr writes to the MXCSR control and status register. +func ldmxcsr(addr *uint32) + // readCR2 reads the current CR2 value. func readCR2() uintptr diff --git a/pkg/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s index 2fe83568a..70a43e79e 100644 --- a/pkg/ring0/lib_amd64.s +++ b/pkg/ring0/lib_amd64.s @@ -198,3 +198,15 @@ TEXT ·rdmsr(SB),NOSPLIT,$0-16 MOVL AX, ret+8(FP) MOVL DX, ret+12(FP) RET + +// stmxcsr reads the MXCSR control and status register. +TEXT ·stmxcsr(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), SI + STMXCSR (SI) + RET + +// ldmxcsr writes to the MXCSR control and status register. 
+TEXT ·ldmxcsr(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), SI + LDMXCSR (SI) + RET diff --git a/pkg/ring0/lib_arm64.go b/pkg/ring0/lib_arm64.go index e44df00a6..5eabd4296 100644 --- a/pkg/ring0/lib_arm64.go +++ b/pkg/ring0/lib_arm64.go @@ -31,6 +31,9 @@ func FlushTlbByVA(addr uintptr) // FlushTlbByASID invalidates tlb by ASID/Inner-Shareable. func FlushTlbByASID(asid uintptr) +// LocalFlushTlbByASID invalidates tlb by ASID. +func LocalFlushTlbByASID(asid uintptr) + // FlushTlbAll invalidates all tlb. func FlushTlbAll() diff --git a/pkg/ring0/lib_arm64.s b/pkg/ring0/lib_arm64.s index e39b32841..69ebaf519 100644 --- a/pkg/ring0/lib_arm64.s +++ b/pkg/ring0/lib_arm64.s @@ -32,6 +32,14 @@ TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8 DSB $11 // dsb(ish) RET +TEXT ·LocalFlushTlbByASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + LSL $TLBI_ASID_SHIFT, R1, R1 + DSB $10 // dsb(ishst) + WORD $0xd5088741 // tlbi aside1, x1 + DSB $11 // dsb(ish) + RET + TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 DSB $6 // dsb(nshst) WORD $0xd508871f // __tlbi(vmalle1) diff --git a/pkg/ring0/pagetables/BUILD b/pkg/ring0/pagetables/BUILD index f8f160cc6..f855f4d42 100644 --- a/pkg/ring0/pagetables/BUILD +++ b/pkg/ring0/pagetables/BUILD @@ -84,8 +84,5 @@ go_test( ":walker_check_arm64", ], library = ":pagetables", - deps = [ - "//pkg/hostarch", - "//pkg/usermem", - ], + deps = ["//pkg/hostarch"], ) diff --git a/pkg/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 3f17fba49..9dac53c80 100644 --- a/pkg/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -322,3 +322,12 @@ func (p *PageTables) Lookup(addr hostarch.Addr, findFirst bool) (virtual hostarc func (p *PageTables) MarkReadOnlyShared() { p.readOnlyShared = true } + +// PrefaultRootTable touches the root table page to be sure that its physical +// pages are mapped. 
+// +//go:nosplit +//go:noinline +func (p *PageTables) PrefaultRootTable() PTE { + return p.root[0] +} diff --git a/pkg/ring0/pagetables/pagetables_arm64_test.go b/pkg/ring0/pagetables/pagetables_arm64_test.go index 69320c2fb..2514b9ac5 100644 --- a/pkg/ring0/pagetables/pagetables_arm64_test.go +++ b/pkg/ring0/pagetables/pagetables_arm64_test.go @@ -19,24 +19,24 @@ package pagetables import ( "testing" - "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/hostarch" ) func Test2MAnd4K(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a small page and a huge page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) - pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*47) - pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42) - pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47) + pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: false}, pteSize*42) + pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: false}, pmdSize*47) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, - {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, - {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}}, - {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}}, + {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: true}}, + {0xffff000000400000, pteSize, pteSize 
* 42, MapOpts{AccessType: hostarch.ReadWrite, User: false}}, + {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: hostarch.Read, User: false}}, }) } @@ -44,12 +44,12 @@ func Test1GAnd4K(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a small page and a super page. - pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42) - pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47) + pt.Map(0x400000, pteSize, MapOpts{AccessType: hostarch.ReadWrite, User: true}, pteSize*42) + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*47) checkMappings(t, pt, []mapping{ - {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}}, - {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}}, + {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: hostarch.ReadWrite, User: true}}, + {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: hostarch.Read, User: true}}, }) } @@ -57,12 +57,12 @@ func TestSplit1GPage(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a super page and knock out the middle. 
- pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42) + pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: hostarch.Read, User: true}, pudSize*42) pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize)) checkMappings(t, pt, []mapping{ - {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, - {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: hostarch.Read, User: true}}, + {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}}, }) } @@ -70,11 +70,11 @@ func TestSplit2MPage(t *testing.T) { pt := New(NewRuntimeAllocator()) // Map a huge page and knock out the middle. - pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42) + pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: hostarch.Read, User: true}, pmdSize*42) pt.Unmap(hostarch.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize)) checkMappings(t, pt, []mapping{ - {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}}, - {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}}, + {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: hostarch.Read, User: true}}, + {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: hostarch.Read, User: true}}, }) } diff --git a/pkg/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s index a0cd78f33..d513f16c9 100644 --- a/pkg/safecopy/atomic_amd64.s +++ b/pkg/safecopy/atomic_amd64.s @@ -24,12 +24,12 @@ TEXT handleSwapUint32Fault(SB), NOSPLIT, $0-24 MOVL DI, sig+20(FP) RET -// swapUint32 atomically stores new into *addr and returns (the previous *addr 
+// swapUint32 atomically stores new into *ptr and returns (the previous *ptr // value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the // value of old is unspecified, and sig is the number of the signal that was // received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32) TEXT ·swapUint32(SB), NOSPLIT, $0-24 @@ -38,12 +38,18 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24 // handleSwapUint32Fault will store a different value in this address. MOVL $0, sig+20(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVL new+8(FP), AX XCHGL AX, 0(DI) MOVL AX, old+16(FP) RET +// func addrOfSwapUint32() uintptr +TEXT ·addrOfSwapUint32(SB), $0-8 + MOVQ $·swapUint32(SB), AX + MOVQ AX, ret+0(FP) + RET + // handleSwapUint64Fault returns the value stored in DI. Control is transferred // to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal // number stored in DI. @@ -54,12 +60,12 @@ TEXT handleSwapUint64Fault(SB), NOSPLIT, $0-28 MOVL DI, sig+24(FP) RET -// swapUint64 atomically stores new into *addr and returns (the previous *addr +// swapUint64 atomically stores new into *ptr and returns (the previous *ptr // value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the // value of old is unspecified, and sig is the number of the signal that was // received. // -// Preconditions: addr must be aligned to a 8-byte boundary. +// Preconditions: ptr must be aligned to a 8-byte boundary. // //func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32) TEXT ·swapUint64(SB), NOSPLIT, $0-28 @@ -68,12 +74,18 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28 // handleSwapUint64Fault will store a different value in this address.
MOVL $0, sig+24(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVQ new+8(FP), AX XCHGQ AX, 0(DI) MOVQ AX, old+16(FP) RET +// func addrOfSwapUint64() uintptr +TEXT ·addrOfSwapUint64(SB), $0-8 + MOVQ $·swapUint64(SB), AX + MOVQ AX, ret+0(FP) + RET + // handleCompareAndSwapUint32Fault returns the value stored in DI. Control is // transferred to it when swapUint64 below receives SIGSEGV or SIGBUS, with the // signal number stored in DI. @@ -85,11 +97,11 @@ TEXT handleCompareAndSwapUint32Fault(SB), NOSPLIT, $0-24 RET // compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns -// (the value previously stored at addr, 0). If a SIGSEGV or SIGBUS signal is +// (the value previously stored at ptr, 0). If a SIGSEGV or SIGBUS signal is // received during the operation, the value of prev is unspecified, and sig is // the number of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32) TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 @@ -99,7 +111,7 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 // address. MOVL $0, sig+20(FP) - MOVQ addr+0(FP), DI + MOVQ ptr+0(FP), DI MOVL old+8(FP), AX MOVL new+12(FP), DX LOCK @@ -107,6 +119,12 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 MOVL AX, prev+16(FP) RET +// func addrOfCompareAndSwapUint32() uintptr +TEXT ·addrOfCompareAndSwapUint32(SB), $0-8 + MOVQ $·compareAndSwapUint32(SB), AX + MOVQ AX, ret+0(FP) + RET + // handleLoadUint32Fault returns the value stored in DI. Control is transferred // to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal // number stored in DI. @@ -117,11 +135,11 @@ TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 MOVL DI, sig+12(FP) RET -// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS +// loadUint32 atomically loads *ptr and returns it. 
If a SIGSEGV or SIGBUS // signal is received, the value returned is unspecified, and sig is the number // of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) TEXT ·loadUint32(SB), NOSPLIT, $0-16 @@ -130,7 +148,13 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16 // handleLoadUint32Fault will store a different value in this address. MOVL $0, sig+12(FP) - MOVQ addr+0(FP), AX + MOVQ ptr+0(FP), AX MOVL (AX), BX MOVL BX, val+8(FP) RET + +// func addrOfLoadUint32() uintptr +TEXT ·addrOfLoadUint32(SB), $0-8 + MOVQ $·loadUint32(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s index d58ed71f7..246a049ba 100644 --- a/pkg/safecopy/atomic_arm64.s +++ b/pkg/safecopy/atomic_arm64.s @@ -25,7 +25,7 @@ TEXT ·swapUint32(SB), NOSPLIT, $0-24 // handleSwapUint32Fault will store a different value in this address. MOVW $0, sig+20(FP) again: - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVW new+8(FP), R1 LDAXRW (R0), R2 STLXRW R1, (R0), R3 @@ -33,6 +33,12 @@ again: MOVW R2, old+16(FP) RET +// func addrOfSwapUint32() uintptr +TEXT ·addrOfSwapUint32(SB), $0-8 + MOVD $·swapUint32(SB), R0 + MOVD R0, ret+0(FP) + RET + // handleSwapUint64Fault returns the value stored in R1. Control is transferred // to it when swapUint64 below receives SIGSEGV or SIGBUS, with the signal // number stored in R1. @@ -54,7 +60,7 @@ TEXT ·swapUint64(SB), NOSPLIT, $0-28 // handleSwapUint64Fault will store a different value in this address. MOVW $0, sig+24(FP) again: - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVD new+8(FP), R1 LDAXR (R0), R2 STLXR R1, (R0), R3 @@ -62,6 +68,12 @@ again: MOVD R2, old+16(FP) RET +// func addrOfSwapUint64() uintptr +TEXT ·addrOfSwapUint64(SB), $0-8 + MOVD $·swapUint64(SB), R0 + MOVD R0, ret+0(FP) + RET + // handleCompareAndSwapUint32Fault returns the value stored in R1. 
Control is // transferred to it when compareAndSwapUint32 below receives SIGSEGV or SIGBUS, // with the signal number stored in R1. @@ -84,7 +96,7 @@ TEXT ·compareAndSwapUint32(SB), NOSPLIT, $0-24 // address. MOVW $0, sig+20(FP) - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 MOVW old+8(FP), R1 MOVW new+12(FP), R2 again: @@ -97,6 +109,12 @@ done: MOVW R3, prev+16(FP) RET +// func addrOfCompareAndSwapUint32() uintptr +TEXT ·addrOfCompareAndSwapUint32(SB), $0-8 + MOVD $·compareAndSwapUint32(SB), R0 + MOVD R0, ret+0(FP) + RET + // handleLoadUint32Fault returns the value stored in DI. Control is transferred // to it when LoadUint32 below receives SIGSEGV or SIGBUS, with the signal // number stored in DI. @@ -107,11 +125,11 @@ TEXT handleLoadUint32Fault(SB), NOSPLIT, $0-16 MOVW R1, sig+12(FP) RET -// loadUint32 atomically loads *addr and returns it. If a SIGSEGV or SIGBUS +// loadUint32 atomically loads *ptr and returns it. If a SIGSEGV or SIGBUS // signal is received, the value returned is unspecified, and sig is the number // of the signal that was received. // -// Preconditions: addr must be aligned to a 4-byte boundary. +// Preconditions: ptr must be aligned to a 4-byte boundary. // //func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) TEXT ·loadUint32(SB), NOSPLIT, $0-16 @@ -120,7 +138,13 @@ TEXT ·loadUint32(SB), NOSPLIT, $0-16 // handleLoadUint32Fault will store a different value in this address. 
MOVW $0, sig+12(FP) - MOVD addr+0(FP), R0 + MOVD ptr+0(FP), R0 LDARW (R0), R1 MOVW R1, val+8(FP) RET + +// func addrOfLoadUint32() uintptr +TEXT ·addrOfLoadUint32(SB), $0-8 + MOVD $·loadUint32(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/safecopy/memclr_amd64.s b/pkg/safecopy/memclr_amd64.s index 64cf32f05..4abaecaff 100644 --- a/pkg/safecopy/memclr_amd64.s +++ b/pkg/safecopy/memclr_amd64.s @@ -145,3 +145,9 @@ _129through256: MOVOU X0, -32(DI)(BX*1) MOVOU X0, -16(DI)(BX*1) RET + +// func addrOfMemclr() uintptr +TEXT ·addrOfMemclr(SB), $0-8 + MOVQ $·memclr(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/safecopy/memclr_arm64.s b/pkg/safecopy/memclr_arm64.s index 7361b9067..c789bfeb3 100644 --- a/pkg/safecopy/memclr_arm64.s +++ b/pkg/safecopy/memclr_arm64.s @@ -72,3 +72,9 @@ head_loop: CMP $16, R1 BLT tail_zero B aligned_to_16 + +// func addrOfMemclr() uintptr +TEXT ·addrOfMemclr(SB), $0-8 + MOVD $·memclr(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s index 00b46c18f..37316b2f5 100644 --- a/pkg/safecopy/memcpy_amd64.s +++ b/pkg/safecopy/memcpy_amd64.s @@ -51,8 +51,8 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36 // handleMemcpyFault will store a different value in this address. MOVL $0, sig+32(FP) - MOVQ to+0(FP), DI - MOVQ from+8(FP), SI + MOVQ dst+0(FP), DI + MOVQ src+8(FP), SI MOVQ n+16(FP), BX tail: @@ -217,3 +217,9 @@ move_129through256: MOVOU -16(SI)(BX*1), X15 MOVOU X15, -16(DI)(BX*1) RET + +// func addrOfMemcpy() uintptr +TEXT ·addrOfMemcpy(SB), $0-8 + MOVQ $·memcpy(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s index e7e541565..50f5b754b 100644 --- a/pkg/safecopy/memcpy_arm64.s +++ b/pkg/safecopy/memcpy_arm64.s @@ -33,8 +33,8 @@ TEXT ·memcpy(SB), NOSPLIT, $-8-36 // handleMemcpyFault will store a different value in this address. 
MOVW $0, sig+32(FP) - MOVD to+0(FP), R3 - MOVD from+8(FP), R4 + MOVD dst+0(FP), R3 + MOVD src+8(FP), R4 MOVD n+16(FP), R5 CMP $0, R5 BNE check @@ -76,3 +76,9 @@ forwardtailloop: CMP R3, R9 BNE forwardtailloop RET + +// func addrOfMemcpy() uintptr +TEXT ·addrOfMemcpy(SB), $0-8 + MOVD $·memcpy(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/safecopy/safecopy.go b/pkg/safecopy/safecopy.go index 1e0af5889..df63dd5f1 100644 --- a/pkg/safecopy/safecopy.go +++ b/pkg/safecopy/safecopy.go @@ -18,7 +18,6 @@ package safecopy import ( "fmt" - "reflect" "runtime" "golang.org/x/sys/unix" @@ -91,6 +90,11 @@ var ( // signals. func signalHandler() +// addrOfSignalHandler returns the start address of signalHandler. +// +// See comment on addrOfMemcpy for more details. +func addrOfSignalHandler() uintptr + // FindEndAddress returns the end address (one byte beyond the last) of the // function that contains the specified address (begin). func FindEndAddress(begin uintptr) uintptr { @@ -111,26 +115,26 @@ func initializeAddresses() { // The following functions are written in assembly language, so they won't // be inlined by the existing compiler/linker. Tests will fail if this // assumption is violated. 
- memcpyBegin = reflect.ValueOf(memcpy).Pointer() + memcpyBegin = addrOfMemcpy() memcpyEnd = FindEndAddress(memcpyBegin) - memclrBegin = reflect.ValueOf(memclr).Pointer() + memclrBegin = addrOfMemclr() memclrEnd = FindEndAddress(memclrBegin) - swapUint32Begin = reflect.ValueOf(swapUint32).Pointer() + swapUint32Begin = addrOfSwapUint32() swapUint32End = FindEndAddress(swapUint32Begin) - swapUint64Begin = reflect.ValueOf(swapUint64).Pointer() + swapUint64Begin = addrOfSwapUint64() swapUint64End = FindEndAddress(swapUint64Begin) - compareAndSwapUint32Begin = reflect.ValueOf(compareAndSwapUint32).Pointer() + compareAndSwapUint32Begin = addrOfCompareAndSwapUint32() compareAndSwapUint32End = FindEndAddress(compareAndSwapUint32Begin) - loadUint32Begin = reflect.ValueOf(loadUint32).Pointer() + loadUint32Begin = addrOfLoadUint32() loadUint32End = FindEndAddress(loadUint32Begin) } func init() { initializeAddresses() - if err := ReplaceSignalHandler(unix.SIGSEGV, reflect.ValueOf(signalHandler).Pointer(), &savedSigSegVHandler); err != nil { + if err := ReplaceSignalHandler(unix.SIGSEGV, addrOfSignalHandler(), &savedSigSegVHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGSEGV: %v", err)) } - if err := ReplaceSignalHandler(unix.SIGBUS, reflect.ValueOf(signalHandler).Pointer(), &savedSigBusHandler); err != nil { + if err := ReplaceSignalHandler(unix.SIGBUS, addrOfSignalHandler(), &savedSigBusHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for SIGBUS: %v", err)) } syserror.AddErrorUnwrapper(func(e error) (unix.Errno, bool) { diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go index 611f36253..55743e69c 100644 --- a/pkg/safecopy/safecopy_test.go +++ b/pkg/safecopy/safecopy_test.go @@ -19,8 +19,6 @@ import ( "fmt" "io/ioutil" "math/rand" - "os" - "runtime/debug" "testing" "unsafe" @@ -568,63 +566,3 @@ func TestCompareAndSwapUint32BusError(t *testing.T) { } }) } - -func testCopy(dst, src []byte) (panicked bool) { - defer 
func() { - if r := recover(); r != nil { - panicked = true - } - }() - debug.SetPanicOnFault(true) - copy(dst, src) - return -} - -func TestSegVOnMemmove(t *testing.T) { - // Test that SIGSEGVs received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - a, err := unix.Mmap(-1, 0, bufLen, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer unix.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} - -func TestSigbusOnMemmove(t *testing.T) { - // Test that SIGBUS received by runtime.memmove when *not* doing - // CopyIn or CopyOut work gets propagated to the runtime. - const bufLen = pageSize - f, err := ioutil.TempFile("", "sigbus_test") - if err != nil { - t.Fatalf("TempFile failed: %v", err) - } - os.Remove(f.Name()) - defer f.Close() - - a, err := unix.Mmap(int(f.Fd()), 0, bufLen, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) - if err != nil { - t.Fatalf("Mmap failed: %v", err) - - } - defer unix.Munmap(a) - b := randBuf(bufLen) - - if !testCopy(b, a) { - t.Fatalf("testCopy didn't panic when it should have") - } - - if !testCopy(a, b) { - t.Fatalf("testCopy didn't panic when it should have") - } -} diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index a075cf88e..efbc2ddc1 100644 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go @@ -89,6 +89,18 @@ func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig //go:noescape func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32) +// Return the start address of the functions above. +// +// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal +// wrapper function rather than the function itself. 
We must reference from +// assembly to get the ABI0 (i.e., primary) address. +func addrOfMemcpy() uintptr +func addrOfMemclr() uintptr +func addrOfSwapUint32() uintptr +func addrOfSwapUint64() uintptr +func addrOfCompareAndSwapUint32() uintptr +func addrOfLoadUint32() uintptr + // CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes // copied and an error if SIGSEGV or SIGBUS is received while reading from src. func CopyIn(dst []byte, src unsafe.Pointer) (int, error) { diff --git a/pkg/safecopy/sighandler_amd64.s b/pkg/safecopy/sighandler_amd64.s index 475ae48e9..0b5e8df66 100644 --- a/pkg/safecopy/sighandler_amd64.s +++ b/pkg/safecopy/sighandler_amd64.s @@ -131,3 +131,9 @@ handle_fault: MOVL DI, REG_RDI(DX) RET + +// func addrOfSignalHandler() uintptr +TEXT ·addrOfSignalHandler(SB), $0-8 + MOVQ $·signalHandler(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/safecopy/sighandler_arm64.s b/pkg/safecopy/sighandler_arm64.s index 53e4ac2c1..41ed70ff9 100644 --- a/pkg/safecopy/sighandler_arm64.s +++ b/pkg/safecopy/sighandler_arm64.s @@ -141,3 +141,9 @@ handle_fault: MOVW R0, REG_R1(R2) RET + +// func addrOfSignalHandler() uintptr +TEXT ·addrOfSignalHandler(SB), $0-8 + MOVD $·signalHandler(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD index 3fda3a9cc..2c7cc8769 100644 --- a/pkg/safemem/BUILD +++ b/pkg/safemem/BUILD @@ -14,6 +14,7 @@ go_library( deps = [ "//pkg/gohacks", "//pkg/safecopy", + "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go index 93879bb4f..4af534385 100644 --- a/pkg/safemem/block_unsafe.go +++ b/pkg/safemem/block_unsafe.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/safecopy" + "gvisor.dev/gvisor/pkg/sync" ) // A Block is a range of contiguous bytes, similar to []byte but with the @@ -223,8 +224,22 @@ func Copy(dst, src Block) (int, error) { func Zero(dst Block) (int, 
error) { if !dst.needSafecopy { bs := dst.ToSlice() - for i := range bs { - bs[i] = 0 + if !sync.RaceEnabled { + // If the race detector isn't enabled, the golang + // compiler replaces the next loop with memclr + // (https://github.com/golang/go/issues/5373). + for i := range bs { + bs[i] = 0 + } + } else { + bsLen := len(bs) + if bsLen == 0 { + return 0, nil + } + bs[0] = 0 + for i := 1; i < bsLen; i *= 2 { + copy(bs[i:], bs[:i]) + } } return len(bs), nil } diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go index 1e9625bee..f0ba26736 100644 --- a/pkg/sentry/arch/fpu/fpu_amd64.go +++ b/pkg/sentry/arch/fpu/fpu_amd64.go @@ -219,6 +219,11 @@ func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid return copy(*s, f), nil } +// SetMXCSR sets the MXCSR control/status register in the state. +func (s *State) SetMXCSR(mxcsr uint32) { + hostarch.ByteOrder.PutUint32((*s)[mxcsrOffset:], mxcsr) +} + // BytePointer returns a pointer to the first byte of the state. // //go:nosplit diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go index 1929e41cd..49c53452a 100644 --- a/pkg/sentry/devices/memdev/zero.go +++ b/pkg/sentry/devices/memdev/zero.go @@ -93,6 +93,7 @@ func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro // "/dev/zero (deleted)". opts.Offset = 0 opts.MappingIdentity = &fd.vfsfd + opts.SentryOwnedContent = true opts.MappingIdentity.IncRef() return nil } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 0b3d0617f..46a2dc47d 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -384,8 +384,16 @@ func (c *ConnectedEndpoint) CloseUnread() {} // SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize. func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { - // gVisor does not permit setting of SO_SNDBUF for host backed unix domain - // sockets. 
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix + // domain sockets. + return atomic.LoadInt64(&c.sndbuf) +} + +// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize. +func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) { + // gVisor does not permit setting of SO_RCVBUF for host backed unix + // domain sockets. Receive buffer does not have any effect for unix + // sockets and we claim to be the same as send buffer. return atomic.LoadInt64(&c.sndbuf) } diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD new file mode 100644 index 000000000..37efb641a --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/BUILD @@ -0,0 +1,48 @@ +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +licenses(["notice"]) + +go_template_instance( + name = "dir_refs", + out = "dir_refs.go", + package = "cgroupfs", + prefix = "dir", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "dir", + }, +) + +go_library( + name = "cgroupfs", + srcs = [ + "base.go", + "cgroupfs.go", + "cpu.go", + "cpuacct.go", + "cpuset.go", + "dir_refs.go", + "job.go", + "memory.go", + ], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/coverage", + "//pkg/log", + "//pkg/refs", + "//pkg/refsvfs2", + "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", + "//pkg/sentry/usage", + "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + ], +) diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go new file mode 100644 index 000000000..0f54888d8 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/base.go @@ -0,0 +1,261 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "bytes" + "fmt" + "sort" + "strconv" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// controllerCommon implements kernel.CgroupController. +// +// Must call init before use. +// +// +stateify savable +type controllerCommon struct { + ty kernel.CgroupControllerType + fs *filesystem +} + +func (c *controllerCommon) init(ty kernel.CgroupControllerType, fs *filesystem) { + c.ty = ty + c.fs = fs +} + +// Type implements kernel.CgroupController.Type. +func (c *controllerCommon) Type() kernel.CgroupControllerType { + return kernel.CgroupControllerType(c.ty) +} + +// HierarchyID implements kernel.CgroupController.HierarchyID. +func (c *controllerCommon) HierarchyID() uint32 { + return c.fs.hierarchyID +} + +// NumCgroups implements kernel.CgroupController.NumCgroups. +func (c *controllerCommon) NumCgroups() uint64 { + return atomic.LoadUint64(&c.fs.numCgroups) +} + +// Enabled implements kernel.CgroupController.Enabled. +// +// Controllers are currently always enabled. +func (c *controllerCommon) Enabled() bool { + return true +} + +// Filesystem implements kernel.CgroupController.Filesystem. 
+func (c *controllerCommon) Filesystem() *vfs.Filesystem { + return c.fs.VFSFilesystem() +} + +// RootCgroup implements kernel.CgroupController.RootCgroup. +func (c *controllerCommon) RootCgroup() kernel.Cgroup { + return c.fs.rootCgroup() +} + +// controller is an interface for common functionality related to all cgroups. +// It is an extension of the public cgroup interface, containing cgroup +// functionality private to cgroupfs. +type controller interface { + kernel.CgroupController + + // AddControlFiles should extend the contents map with inodes representing + // control files defined by this controller. + AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode) +} + +// cgroupInode implements kernel.CgroupImpl and kernfs.Inode. +// +// +stateify savable +type cgroupInode struct { + dir + fs *filesystem + + // ts is the list of tasks in this cgroup. The kernel is responsible for + // removing tasks from this list before they're destroyed, so any tasks on + // this list are always valid. + // + // ts, and cgroup membership in general is protected by fs.tasksMu. 
+ ts map[*kernel.Task]struct{} +} + +var _ kernel.CgroupImpl = (*cgroupInode)(nil) + +func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode { + c := &cgroupInode{ + fs: fs, + ts: make(map[*kernel.Task]struct{}), + } + + contents := make(map[string]kernfs.Inode) + contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c}) + contents["tasks"] = fs.newControllerFile(ctx, creds, &tasksData{c}) + + for _, ctl := range fs.controllers { + ctl.AddControlFiles(ctx, creds, c, contents) + } + + c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555)) + c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + c.dir.InitRefs() + c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents)) + + atomic.AddUint64(&fs.numCgroups, 1) + + return c +} + +func (c *cgroupInode) HierarchyID() uint32 { + return c.fs.hierarchyID +} + +// Controllers implements kernel.CgroupImpl.Controllers. +func (c *cgroupInode) Controllers() []kernel.CgroupController { + return c.fs.kcontrollers +} + +// Enter implements kernel.CgroupImpl.Enter. +func (c *cgroupInode) Enter(t *kernel.Task) { + c.fs.tasksMu.Lock() + c.ts[t] = struct{}{} + c.fs.tasksMu.Unlock() +} + +// Leave implements kernel.CgroupImpl.Leave. +func (c *cgroupInode) Leave(t *kernel.Task) { + c.fs.tasksMu.Lock() + delete(c.ts, t) + c.fs.tasksMu.Unlock() +} + +func sortTIDs(tids []kernel.ThreadID) { + sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] }) +} + +// +stateify savable +type cgroupProcsData struct { + *cgroupInode +} + +// Generate implements vfs.DynamicBytesSource.Generate. 
+func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + t := kernel.TaskFromContext(ctx) + currPidns := t.ThreadGroup().PIDNamespace() + + pgids := make(map[kernel.ThreadID]struct{}) + + d.fs.tasksMu.RLock() + defer d.fs.tasksMu.RUnlock() + + for task := range d.ts { + // Map dedups pgid, since iterating over all tasks produces multiple + // entries for the group leaders. + if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 { + pgids[pgid] = struct{}{} + } + } + + pgidList := make([]kernel.ThreadID, 0, len(pgids)) + for pgid, _ := range pgids { + pgidList = append(pgidList, pgid) + } + sortTIDs(pgidList) + + for _, pgid := range pgidList { + fmt.Fprintf(buf, "%d\n", pgid) + } + + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + // TODO(b/183137098): Payload is the pid for a process to add to this cgroup. + return src.NumBytes(), nil +} + +// +stateify savable +type tasksData struct { + *cgroupInode +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error { + t := kernel.TaskFromContext(ctx) + currPidns := t.ThreadGroup().PIDNamespace() + + var pids []kernel.ThreadID + + d.fs.tasksMu.RLock() + defer d.fs.tasksMu.RUnlock() + + for task := range d.ts { + if pid := currPidns.IDOfTask(task); pid != 0 { + pids = append(pids, pid) + } + } + sortTIDs(pids) + + for _, pid := range pids { + fmt.Fprintf(buf, "%d\n", pid) + } + + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + // TODO(b/183137098): Payload is the pid for a process to add to this cgroup. 
+ return src.NumBytes(), nil +} + +// parseInt64FromString interprets src as a string encoding an int64 value, and +// returns the parsed value. +func parseInt64FromString(ctx context.Context, src usermem.IOSequence, offset int64) (val, len int64, err error) { + const maxInt64StrLen = 20 // i.e. len(fmt.Sprintf("%d", math.MinInt64)) == 20 + + t := kernel.TaskFromContext(ctx) + src = src.DropFirst64(offset) + + buf := t.CopyScratchBuffer(maxInt64StrLen) + n, err := src.CopyIn(ctx, buf) + if err != nil { + return 0, int64(n), err + } + buf = buf[:n] + + val, err = strconv.ParseInt(string(buf), 10, 64) + if err != nil { + // Note: This also handles zero-len writes if offset is beyond the end + // of src, or src is empty. + ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err) + return 0, int64(n), syserror.EINVAL + } + + return val, int64(n), nil +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go new file mode 100644 index 000000000..bd3e69757 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -0,0 +1,425 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cgroupfs implements cgroupfs. +// +// A cgroup is a collection of tasks on the system, organized into a tree-like +// structure similar to a filesystem directory tree.
In fact, each cgroup is +// represented by a directory on cgroupfs, and is manipulated through control +// files in the directory. +// +// All cgroups on a system are organized into hierarchies. Each hierarchy is a +// distinct tree of cgroups, with a common set of controllers. One or more +// cgroupfs mounts may point to each hierarchy. These mounts provide a common +// view into the same tree of cgroups. +// +// A controller (also known as a "resource controller", or a cgroup "subsystem") +// determines the behaviour of each cgroup. +// +// In addition to cgroupfs, the kernel has a cgroup registry that tracks +// system-wide state related to cgroups such as active hierarchies and the +// controllers associated with them. +// +// Since cgroupfs doesn't allow hardlinks, there is a unique mapping between +// cgroupfs dentries and inodes. +// +// # Synchronization +// +// Cgroup hierarchy creation and destruction are protected by +// kernel.CgroupRegistry.mu. Once created, a hierarchy's set of controllers, the +// filesystem associated with it, and the root cgroup for the hierarchy are +// immutable. +// +// Membership of tasks within cgroups is protected by +// cgroupfs.filesystem.tasksMu. Tasks also maintain a set of all cgroups they're +// in, and this list is protected by Task.mu. +// +// Lock order: +// +// kernel.CgroupRegistry.mu +// cgroupfs.filesystem.mu +// Task.mu +// cgroupfs.filesystem.tasksMu. +package cgroupfs + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" +) + +const ( + // Name is the default filesystem name.
+ Name = "cgroup" + readonlyFileMode = linux.FileMode(0444) + writableFileMode = linux.FileMode(0644) + defaultMaxCachedDentries = uint64(1000) +) + +const ( + controllerCPU = kernel.CgroupControllerType("cpu") + controllerCPUAcct = kernel.CgroupControllerType("cpuacct") + controllerCPUSet = kernel.CgroupControllerType("cpuset") + controllerJob = kernel.CgroupControllerType("job") + controllerMemory = kernel.CgroupControllerType("memory") +) + +var allControllers = []kernel.CgroupControllerType{ + controllerCPU, + controllerCPUAcct, + controllerCPUSet, + controllerJob, + controllerMemory, +} + +// SupportedMountOptions is the set of supported mount options for cgroupfs. +var SupportedMountOptions = []string{"all", "cpu", "cpuacct", "cpuset", "job", "memory"} + +// FilesystemType implements vfs.FilesystemType. +// +// +stateify savable +type FilesystemType struct{} + +// InternalData contains internal data passed in to the cgroupfs mount via +// vfs.GetFilesystemOptions.InternalData. +// +// +stateify savable +type InternalData struct { + DefaultControlValues map[string]int64 +} + +// filesystem implements vfs.FilesystemImpl. +// +// +stateify savable +type filesystem struct { + kernfs.Filesystem + devMinor uint32 + + // hierarchyID is the id the cgroup registry assigns to this hierarchy. Has + // the value kernel.InvalidCgroupHierarchyID until the FS is fully + // initialized. + // + // hierarchyID is immutable after initialization. + hierarchyID uint32 + + // controllers and kcontrollers are both the list of controllers attached to + // this cgroupfs. Both lists are the same set of controllers, but typecast + // to different interfaces for convenience. Both must stay in sync, and are + // immutable. + controllers []controller + kcontrollers []kernel.CgroupController + + numCgroups uint64 // Protected by atomic ops. + + root *kernfs.Dentry + + // tasksMu serializes task membership changes across all cgroups within a + // filesystem. 
+ tasksMu sync.RWMutex `state:"nosave"` +} + +// Name implements vfs.FilesystemType.Name. +func (FilesystemType) Name() string { + return Name +} + +// Release implements vfs.FilesystemType.Release. +func (FilesystemType) Release(ctx context.Context) {} + +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + + var wantControllers []kernel.CgroupControllerType + if _, ok := mopts["cpu"]; ok { + delete(mopts, "cpu") + wantControllers = append(wantControllers, controllerCPU) + } + if _, ok := mopts["cpuacct"]; ok { + delete(mopts, "cpuacct") + wantControllers = append(wantControllers, controllerCPUAcct) + } + if _, ok := mopts["cpuset"]; ok { + delete(mopts, "cpuset") + wantControllers = append(wantControllers, controllerCPUSet) + } + if _, ok := mopts["job"]; ok { + delete(mopts, "job") + wantControllers = append(wantControllers, controllerJob) + } + if _, ok := mopts["memory"]; ok { + delete(mopts, "memory") + wantControllers = append(wantControllers, controllerMemory) + } + if _, ok := mopts["all"]; ok { + if len(wantControllers) > 0 { + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers) + return nil, nil, syserror.EINVAL + } + + delete(mopts, "all") + wantControllers = allControllers + } + + 
if len(wantControllers) == 0 { + // Specifying no controllers implies all controllers. + wantControllers = allControllers + } + + if len(mopts) != 0 { + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts) + return nil, nil, syserror.EINVAL + } + + k := kernel.KernelFromContext(ctx) + r := k.CgroupRegistry() + + // "It is not possible to mount the same controller against multiple + // cgroup hierarchies. For example, it is not possible to mount both + // the cpu and cpuacct controllers against one hierarchy, and to mount + // the cpu controller alone against another hierarchy." - man cgroups(7) + // + // Is there a hierarchy available with all the controllers we want? If so, + // this mount is a view into the same hierarchy. + // + // Note: we're guaranteed to have at least one requested controller, since + // no explicit controller name implies all controllers. + if vfsfs := r.FindHierarchy(wantControllers); vfsfs != nil { + fs := vfsfs.Impl().(*filesystem) + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID) + fs.root.IncRef() + return vfsfs, fs.root.VFSDentry(), nil + } + + // No existing hierarchy with exactly the requested controllers was found. Make a new + // one. Note that it's possible this mount creation is unsatisfiable, if one + // or more of the requested controllers are already on existing + // hierarchies. We'll find out about such collisions when we try to register + // the new hierarchy later.
+ fs := &filesystem{ + devMinor: devMinor, + } + fs.MaxCachedDentries = maxCachedDentries + fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + + var defaults map[string]int64 + if opts.InternalData != nil { + ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults) + defaults = opts.InternalData.(*InternalData).DefaultControlValues + } + + for _, ty := range wantControllers { + var c controller + switch ty { + case controllerCPU: + c = newCPUController(fs, defaults) + case controllerCPUAcct: + c = newCPUAcctController(fs) + case controllerCPUSet: + c = newCPUSetController(fs) + case controllerJob: + c = newJobController(fs) + case controllerMemory: + c = newMemoryController(fs, defaults) + default: + panic(fmt.Sprintf("Unreachable: unknown cgroup controller %q", ty)) + } + fs.controllers = append(fs.controllers, c) + } + + if len(defaults) != 0 { + // Internal data is always provided at sentry startup and unused values + // indicate a problem with the sandbox config. Fail fast. + panic(fmt.Sprintf("cgroupfs.FilesystemType.GetFilesystem: unknown internal mount data: %v", defaults)) + } + + // Controllers usually appear in alphabetical order when displayed. Sort it + // here now, so it never needs to be sorted elsewhere. + sort.Slice(fs.controllers, func(i, j int) bool { return fs.controllers[i].Type() < fs.controllers[j].Type() }) + fs.kcontrollers = make([]kernel.CgroupController, 0, len(fs.controllers)) + for _, c := range fs.controllers { + fs.kcontrollers = append(fs.kcontrollers, c) + } + + root := fs.newCgroupInode(ctx, creds) + var rootD kernfs.Dentry + rootD.InitRoot(&fs.Filesystem, root) + fs.root = &rootD + + // Register controllers. The registry may be modified concurrently, so if we + // get an error, we raced with someone else who registered the same + // controllers first. 
+ hid, err := r.Register(fs.kcontrollers) + if err != nil { + ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err) + rootD.DecRef(ctx) + fs.VFSFilesystem().DecRef(ctx) + return nil, nil, syserror.EBUSY + } + fs.hierarchyID = hid + + // Move all existing tasks to the root of the new hierarchy. + k.PopulateNewCgroupHierarchy(fs.rootCgroup()) + + return fs.VFSFilesystem(), rootD.VFSDentry(), nil +} + +func (fs *filesystem) rootCgroup() kernel.Cgroup { + return kernel.Cgroup{ + Dentry: fs.root, + CgroupImpl: fs.root.Inode().(kernel.CgroupImpl), + } +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release(ctx context.Context) { + k := kernel.KernelFromContext(ctx) + r := k.CgroupRegistry() + + if fs.hierarchyID != kernel.InvalidCgroupHierarchyID { + k.ReleaseCgroupHierarchy(fs.hierarchyID) + r.Unregister(fs.hierarchyID) + } + + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release(ctx) +} + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + var cnames []string + for _, c := range fs.controllers { + cnames = append(cnames, string(c.Type())) + } + return strings.Join(cnames, ",") +} + +// +stateify savable +type implStatFS struct{} + +// StatFS implements kernfs.Inode.StatFS. +func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil +} + +// dir implements kernfs.Inode for a generic cgroup resource controller +// directory. Specific controllers extend this to add their own functionality. +// +// +stateify savable +type dir struct { + dirRefs + kernfs.InodeAlwaysValid + kernfs.InodeAttrs + kernfs.InodeNotSymlink + kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir. 
+ kernfs.OrderedChildren + implStatFS + + locks vfs.FileLocks +} + +// Keep implements kernfs.Inode.Keep. +func (*dir) Keep() bool { + return true +} + +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. +func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// Open implements kernfs.Inode.Open. +func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) + if err != nil { + return nil, err + } + return fd.VFSFileDescription(), nil +} + +// DecRef implements kernfs.Inode.DecRef. +func (d *dir) DecRef(ctx context.Context) { + d.dirRefs.DecRef(func() { d.Destroy(ctx) }) +} + +// StatFS implements kernfs.Inode.StatFS. +func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil +} + +// controllerFile represents a generic control file that appears within a cgroup +// directory. 
+// +// +stateify savable +type controllerFile struct { + kernfs.DynamicBytesFile +} + +func (fs *filesystem) newControllerFile(ctx context.Context, creds *auth.Credentials, data vfs.DynamicBytesSource) kernfs.Inode { + f := &controllerFile{} + f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, readonlyFileMode) + return f +} + +func (fs *filesystem) newControllerWritableFile(ctx context.Context, creds *auth.Credentials, data vfs.WritableDynamicBytesSource) kernfs.Inode { + f := &controllerFile{} + f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, writableFileMode) + return f +} + +// staticControllerFile represents a generic control file that appears within a +// cgroup directory which always returns the same data when read. +// staticControllerFiles are not writable. +// +// +stateify savable +type staticControllerFile struct { + kernfs.DynamicBytesFile + vfs.StaticData +} + +// Note: We let the caller provide the mode so that static files may be used to +// fake both readable and writable control files. However, static files are +// effectively readonly, as attempting to write to them will return EIO +// regardless of the mode. +func (fs *filesystem) newStaticControllerFile(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, data string) kernfs.Inode { + f := &staticControllerFile{StaticData: vfs.StaticData{Data: data}} + f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), f, mode) + return f +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cpu.go b/pkg/sentry/fsimpl/cgroupfs/cpu.go new file mode 100644 index 000000000..24d86a277 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cpu.go @@ -0,0 +1,70 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// +stateify savable +type cpuController struct { + controllerCommon + + // CFS bandwidth control parameters, values in microseconds. + cfsPeriod int64 + cfsQuota int64 + + // CPU shares, values should be (num core * 1024). + shares int64 +} + +var _ controller = (*cpuController)(nil) + +func newCPUController(fs *filesystem, defaults map[string]int64) *cpuController { + // Default values for controller parameters from Linux. + c := &cpuController{ + cfsPeriod: 100000, + cfsQuota: -1, + shares: 1024, + } + + if val, ok := defaults["cpu.cfs_period_us"]; ok { + c.cfsPeriod = val + delete(defaults, "cpu.cfs_period_us") + } + if val, ok := defaults["cpu.cfs_quota_us"]; ok { + c.cfsQuota = val + delete(defaults, "cpu.cfs_quota_us") + } + if val, ok := defaults["cpu.shares"]; ok { + c.shares = val + delete(defaults, "cpu.shares") + } + + c.controllerCommon.init(controllerCPU, fs) + return c +} + +// AddControlFiles implements controller.AddControlFiles. 
+func (c *cpuController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { + contents["cpu.cfs_period_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsPeriod)) + contents["cpu.cfs_quota_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsQuota)) + contents["cpu.shares"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.shares)) +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go new file mode 100644 index 000000000..d4104a00e --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go @@ -0,0 +1,114 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usage" +) + +// +stateify savable +type cpuacctController struct { + controllerCommon +} + +var _ controller = (*cpuacctController)(nil) + +func newCPUAcctController(fs *filesystem) *cpuacctController { + c := &cpuacctController{} + c.controllerCommon.init(controllerCPUAcct, fs) + return c +} + +// AddControlFiles implements controller.AddControlFiles. 
+func (c *cpuacctController) AddControlFiles(ctx context.Context, creds *auth.Credentials, cg *cgroupInode, contents map[string]kernfs.Inode) { + cpuacctCG := &cpuacctCgroup{cg} + contents["cpuacct.stat"] = c.fs.newControllerFile(ctx, creds, &cpuacctStatData{cpuacctCG}) + contents["cpuacct.usage"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageData{cpuacctCG}) + contents["cpuacct.usage_user"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageUserData{cpuacctCG}) + contents["cpuacct.usage_sys"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageSysData{cpuacctCG}) +} + +// +stateify savable +type cpuacctCgroup struct { + *cgroupInode +} + +func (c *cpuacctCgroup) collectCPUStats() usage.CPUStats { + var cs usage.CPUStats + c.fs.tasksMu.RLock() + // Note: This isn't very accurate, since the tasks are potentially + // still running as we accumulate their stats. + for t := range c.ts { + cs.Accumulate(t.CPUStats()) + } + c.fs.tasksMu.RUnlock() + return cs +} + +// +stateify savable +type cpuacctStatData struct { + *cpuacctCgroup +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *cpuacctStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { + cs := d.collectCPUStats() + fmt.Fprintf(buf, "user %d\n", linux.ClockTFromDuration(cs.UserTime)) + fmt.Fprintf(buf, "system %d\n", linux.ClockTFromDuration(cs.SysTime)) + return nil +} + +// +stateify savable +type cpuacctUsageData struct { + *cpuacctCgroup +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *cpuacctUsageData) Generate(ctx context.Context, buf *bytes.Buffer) error { + cs := d.collectCPUStats() + fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()+cs.SysTime.Nanoseconds()) + return nil +} + +// +stateify savable +type cpuacctUsageUserData struct { + *cpuacctCgroup +} + +// Generate implements vfs.DynamicBytesSource.Generate. 
+func (d *cpuacctUsageUserData) Generate(ctx context.Context, buf *bytes.Buffer) error { + cs := d.collectCPUStats() + fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()) + return nil +} + +// +stateify savable +type cpuacctUsageSysData struct { + *cpuacctCgroup +} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *cpuacctUsageSysData) Generate(ctx context.Context, buf *bytes.Buffer) error { + cs := d.collectCPUStats() + fmt.Fprintf(buf, "%d\n", cs.SysTime.Nanoseconds()) + return nil +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go new file mode 100644 index 000000000..ac547f8e2 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go @@ -0,0 +1,39 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// +stateify savable +type cpusetController struct { + controllerCommon +} + +var _ controller = (*cpusetController)(nil) + +func newCPUSetController(fs *filesystem) *cpusetController { + c := &cpusetController{} + c.controllerCommon.init(controllerCPUSet, fs) + return c +} + +// AddControlFiles implements controller.AddControlFiles. 
+func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { + // This controller is currently intentionally empty. +} diff --git a/pkg/sentry/fsimpl/cgroupfs/job.go b/pkg/sentry/fsimpl/cgroupfs/job.go new file mode 100644 index 000000000..48919c338 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/job.go @@ -0,0 +1,64 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/usermem" +) + +// +stateify savable +type jobController struct { + controllerCommon + id int64 +} + +var _ controller = (*jobController)(nil) + +func newJobController(fs *filesystem) *jobController { + c := &jobController{} + c.controllerCommon.init(controllerJob, fs) + return c +} + +func (c *jobController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { + contents["job.id"] = c.fs.newControllerWritableFile(ctx, creds, &jobIDData{c: c}) +} + +// +stateify savable +type jobIDData struct { + c *jobController +} + +// Generate implements vfs.DynamicBytesSource.Generate. 
+func (d *jobIDData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "%d\n", d.c.id) + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *jobIDData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + val, n, err := parseInt64FromString(ctx, src, offset) + if err != nil { + return n, err + } + d.c.id = val + return n, nil +} diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go new file mode 100644 index 000000000..485c98376 --- /dev/null +++ b/pkg/sentry/fsimpl/cgroupfs/memory.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cgroupfs + +import ( + "bytes" + "fmt" + "math" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/usage" +) + +// +stateify savable +type memoryController struct { + controllerCommon + + limitBytes int64 +} + +var _ controller = (*memoryController)(nil) + +func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController { + c := &memoryController{ + // Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which + // is ~ 2**63 on a 64-bit system. So essentially, inifinity. The exact + // value isn't very important. 
+ limitBytes: math.MaxInt64, + } + if val, ok := defaults["memory.limit_in_bytes"]; ok { + c.limitBytes = val + delete(defaults, "memory.limit_in_bytes") + } + c.controllerCommon.init(controllerMemory, fs) + return c +} + +// AddControlFiles implements controller.AddControlFiles. +func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) { + contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{}) + contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes)) +} + +// +stateify savable +type memoryUsageInBytesData struct{} + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *memoryUsageInBytesData) Generate(ctx context.Context, buf *bytes.Buffer) error { + // TODO(b/183151557): This is a giant hack, we're using system-wide + // accounting since we know there is only one cgroup. + k := kernel.KernelFromContext(ctx) + mf := k.MemoryFile() + mf.UpdateUsage() + _, totalBytes := usage.MemoryAccounting.Copy() + + fmt.Fprintf(buf, "%d\n", totalBytes) + return nil +} diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 7b1eec3da..2dbc6bfd5 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -46,7 +46,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fd", "//pkg/fspath", diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 6d5258a9b..52879f871 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "host_named_pipe.go", "p9file.go", "regular_file.go", + "revalidate.go", "save_restore.go", "socket.go", "special_file.go", diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 43c3c5a2d..97ce80853 100644 --- 
a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -117,6 +117,17 @@ func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry { return ds } +// Precondition: !parent.isSynthetic() && !child.isSynthetic(). +func appendNewChildDentry(ds **[]*dentry, parent *dentry, child *dentry) { + // The new child was added to parent and took a ref on the parent (hence + // parent can be removed from cache). A new child has 0 refs for now. So + // checkCachingLocked() should be called on both. Call it first on the parent + // as it may create space in the cache for child to be inserted - hence + // avoiding a cache eviction. + *ds = appendDentry(*ds, parent) + *ds = appendDentry(*ds, child) +} + // Preconditions: ds != nil. func putDentrySlice(ds *[]*dentry) { // Allow dentries to be GC'd. @@ -141,21 +152,8 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp ** return } ds := **dsp - // Only go through calling dentry.checkCachingLocked() (which requires - // re-locking renameMu) if we actually have any dentries with zero refs. - checkAny := false - for i := range ds { - if atomic.LoadInt64(&ds[i].refs) == 0 { - checkAny = true - break - } - } - if checkAny { - fs.renameMu.Lock() - for _, d := range ds { - d.checkCachingLocked(ctx) - } - fs.renameMu.Unlock() + for _, d := range ds { + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } putDentrySlice(*dsp) } @@ -166,7 +164,7 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[] return } for _, d := range **ds { - d.checkCachingLocked(ctx) + d.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } fs.renameMu.Unlock() putDentrySlice(*ds) @@ -182,165 +180,96 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[] // * fs.renameMu must be locked. // * d.dirMu must be locked. // * !rp.Done(). -// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up -// to date. 
+// * If !d.cachedMetadataAuthoritative(), then d and all children that are +// part of rp must have been revalidated. // // Postconditions: The returned dentry's cached metadata is up to date. -func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, bool, error) { if !d.isDir() { - return nil, syserror.ENOTDIR + return nil, false, syserror.ENOTDIR } if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, err + return nil, false, err } + followedSymlink := false afterSymlink: name := rp.Component() if name == "." { rp.Advance() - return d, nil + return d, followedSymlink, nil } if name == ".." { if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { - return nil, err + return nil, false, err } else if isRoot || d.parent == nil { rp.Advance() - return d, nil - } - // We must assume that d.parent is correct, because if d has been moved - // elsewhere in the remote filesystem so that its parent has changed, - // we have no way of determining its new parent's location in the - // filesystem. - // - // Call rp.CheckMount() before updating d.parent's metadata, since if - // we traverse to another mount then d.parent's metadata is irrelevant. 
- if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { - return nil, err + return d, followedSymlink, nil } - if d != d.parent && !d.cachedMetadataAuthoritative() { - if err := d.parent.updateFromGetattr(ctx); err != nil { - return nil, err - } + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + return nil, false, err } rp.Advance() - return d.parent, nil + return d.parent, followedSymlink, nil } - child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), d, name, ds) + child, err := fs.getChildLocked(ctx, d, name, ds) if err != nil { - return nil, err - } - if child == nil { - return nil, syserror.ENOENT + return nil, false, err } if err := rp.CheckMount(ctx, &child.vfsd); err != nil { - return nil, err + return nil, false, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx, rp.Mount()) if err != nil { - return nil, err + return nil, false, err } if err := rp.HandleSymlink(target); err != nil { - return nil, err + return nil, false, err } + followedSymlink = true goto afterSymlink // don't check the current directory again } rp.Advance() - return child, nil + return child, followedSymlink, nil } // getChildLocked returns a dentry representing the child of parent with the -// given name. If no such child exists, getChildLocked returns (nil, nil). +// given name. Returns ENOENT if the child doesn't exist. // // Preconditions: // * fs.renameMu must be locked. // * parent.dirMu must be locked. // * parent.isDir(). // * name is not "." or "..". -// -// Postconditions: If getChildLocked returns a non-nil dentry, its cached -// metadata is up to date. 
-func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { +// * dentry at name has been revalidated +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if len(name) > maxFilenameLen { return nil, syserror.ENAMETOOLONG } - child, ok := parent.children[name] - if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() { - // Whether child is nil or not, it is cached information that is - // assumed to be correct. + if child, ok := parent.children[name]; ok || parent.isSynthetic() { + if child == nil { + return nil, syserror.ENOENT + } return child, nil } - // We either don't have cached information or need to verify that it's - // still correct, either of which requires a remote lookup. Check if this - // name is valid before performing the lookup. - return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds) -} -// Preconditions: Same as getChildLocked, plus: -// * !parent.isSynthetic(). -func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { - if child != nil { - // Need to lock child.metadataMu because we might be updating child - // metadata. We need to hold the lock *before* getting metadata from the - // server and release it after updating local metadata. - child.metadataMu.Lock() - } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) - if err != nil && err != syserror.ENOENT { - if child != nil { - child.metadataMu.Unlock() + if err != nil { + if err == syserror.ENOENT { + parent.cacheNegativeLookupLocked(name) } return nil, err } - if child != nil { - if !file.isNil() && qid.Path == child.qidPath { - // The file at this path hasn't changed. Just update cached metadata. 
- file.close(ctx) - child.updateFromP9AttrsLocked(attrMask, &attr) - child.metadataMu.Unlock() - return child, nil - } - child.metadataMu.Unlock() - if file.isNil() && child.isSynthetic() { - // We have a synthetic file, and no remote file has arisen to - // replace it. - return child, nil - } - // The file at this path has changed or no longer exists. Mark the - // dentry invalidated, and re-evaluate its caching status (i.e. if it - // has 0 references, drop it). Wait to update parent.children until we - // know what to replace the existing dentry with (i.e. one of the - // returns below), to avoid a redundant map access. - vfsObj.InvalidateDentry(ctx, &child.vfsd) - if child.isSynthetic() { - // Normally we don't mark invalidated dentries as deleted since - // they may still exist (but at a different path), and also for - // consistency with Linux. However, synthetic files are guaranteed - // to become unreachable if their dentries are invalidated, so - // treat their invalidation as deletion. - child.setDeleted() - parent.syntheticChildren-- - child.decRefNoCaching() - parent.dirents = nil - } - *ds = appendDentry(*ds, child) - } - if file.isNil() { - // No file exists at this path now. Cache the negative lookup if - // allowed. - parent.cacheNegativeLookupLocked(name) - return nil, nil - } + // Create a new dentry representing the file. - child, err = fs.newDentry(ctx, file, qid, attrMask, &attr) + child, err := fs.newDentry(ctx, file, qid, attrMask, &attr) if err != nil { file.close(ctx) delete(parent.children, name) return nil, err } parent.cacheNewChildLocked(child, name) - // For now, child has 0 references, so our caller should call - // child.checkCachingLocked(). - *ds = appendDentry(*ds, child) + appendNewChildDentry(ds, parent, child) return child, nil } @@ -355,14 +284,22 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up // to date. 
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { + if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil { + return nil, err + } for !rp.Final() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err } d = next + if followedSymlink { + if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil { + return nil, err + } + } } if !d.isDir() { return nil, syserror.ENOTDIR @@ -375,20 +312,22 @@ func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // Preconditions: fs.renameMu must be locked. func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { d := rp.Start().Impl().(*dentry) - if !d.cachedMetadataAuthoritative() { - // Get updated metadata for rp.Start() as required by fs.stepLocked(). - if err := d.updateFromGetattr(ctx); err != nil { - return nil, err - } + if err := fs.revalidatePath(ctx, rp, d, ds); err != nil { + return nil, err } for !rp.Done() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err } d = next + if followedSymlink { + if err := fs.revalidatePath(ctx, rp, d, ds); err != nil { + return nil, err + } + } } if rp.MustBeDir() && !d.isDir() { return nil, syserror.ENOTDIR @@ -408,13 +347,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). 
- if err := start.updateFromGetattr(ctx); err != nil { - return err - } - } parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return err @@ -432,25 +364,47 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if parent.isDeleted() { return syserror.ENOENT } + if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, name, &ds); err != nil { + return err + } parent.dirMu.Lock() defer parent.dirMu.Unlock() - child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds) - switch { - case err != nil && err != syserror.ENOENT: - return err - case child != nil: + if len(name) > maxFilenameLen { + return syserror.ENAMETOOLONG + } + // Check for existence only if caching information is available. Otherwise, + // don't check for existence just yet. We will check for existence if the + // checks for writability fail below. Existence check is done by the creation + // RPCs themselves. + if child, ok := parent.children[name]; ok && child != nil { return syserror.EEXIST } + checkExistence := func() error { + if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT { + return err + } else if child != nil { + return syserror.EEXIST + } + return nil + } mnt := rp.Mount() if err := mnt.CheckBeginWrite(); err != nil { + // Existence check takes precedence. + if existenceErr := checkExistence(); existenceErr != nil { + return existenceErr + } return err } defer mnt.EndWrite() if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + // Existence check takes precedence. 
+ if existenceErr := checkExistence(); existenceErr != nil { + return existenceErr + } return err } if !dir && rp.MustBeDir() { @@ -500,13 +454,6 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return err - } - } parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return err @@ -532,33 +479,32 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b return syserror.EISDIR } } + vfsObj := rp.VirtualFilesystem() + if err := fs.revalidateOne(ctx, vfsObj, parent, rp.Component(), &ds); err != nil { + return err + } + mntns := vfs.MountNamespaceFromContext(ctx) defer mntns.DecRef(ctx) + parent.dirMu.Lock() defer parent.dirMu.Unlock() - child, ok := parent.children[name] - if ok && child == nil { - return syserror.ENOENT - } - - sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0 - if sticky { - if !ok { - // If the sticky bit is set, we need to retrieve the child to determine - // whether removing it is allowed. - child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) - if err != nil { - return err - } - } else if child != nil && !child.cachedMetadataAuthoritative() { - // Make sure the dentry representing the file at name is up to date - // before examining its metadata. - child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds) - if err != nil { - return err - } + // Load child if sticky bit is set because we need to determine whether + // deletion is allowed. 
+ var child *dentry + if atomic.LoadUint32(&parent.mode)&linux.ModeSticky == 0 { + var ok bool + child, ok = parent.children[name] + if ok && child == nil { + // Hit a negative cached entry, child doesn't exist. + return syserror.ENOENT + } + } else { + child, _, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + if err != nil { + return err } if err := parent.mayDelete(rp.Credentials(), child); err != nil { return err @@ -567,11 +513,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b // If a child dentry exists, prepare to delete it. This should fail if it is // a mount point. We detect mount points by speculatively calling - // PrepareDeleteDentry, which fails if child is a mount point. However, we - // may need to revalidate the file in this case to make sure that it has not - // been deleted or replaced on the remote fs, in which case the mount point - // will have disappeared. If calling PrepareDeleteDentry fails again on the - // up-to-date dentry, we can be sure that it is a mount point. + // PrepareDeleteDentry, which fails if child is a mount point. // // Also note that if child is nil, then it can't be a mount point. 
if child != nil { @@ -586,23 +528,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b child.dirMu.Lock() defer child.dirMu.Unlock() if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { - // We can skip revalidation in several cases: - // - We are not in InteropModeShared - // - The parent directory is synthetic, in which case the child must also - // be synthetic - // - We already updated the child during the sticky bit check above - if parent.cachedMetadataAuthoritative() || sticky { - return err - } - child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds) - if err != nil { - return err - } - if child != nil { - if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { - return err - } - } + return err } } flags := uint32(0) @@ -723,6 +649,8 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op } } d.IncRef() + // Call d.checkCachingLocked() so it can be removed from the cache if needed. + ds = appendDentry(ds, d) return &d.vfsd, nil } @@ -732,18 +660,13 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by - // fs.walkParentDirLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } d, err := fs.walkParentDirLocked(ctx, rp, start, &ds) if err != nil { return nil, err } d.IncRef() + // Call d.checkCachingLocked() so it can be removed from the cache if needed. + ds = appendDentry(ds, d) return &d.vfsd, nil } @@ -782,7 +705,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. // MkdirAt implements vfs.FilesystemImpl.MkdirAt. 
func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { creds := rp.Credentials() - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { + return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { // If the parent is a setgid directory, use the parent's GID // rather than the caller's and enable setgid. kgid := creds.EffectiveKGID @@ -802,6 +725,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kuid: creds.EffectiveKUID, kgid: creds.EffectiveKGID, }) + *ds = appendDentry(*ds, parent) } if fs.opts.interop != InteropModeShared { parent.incLinks() @@ -836,7 +760,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // to creating a synthetic one, i.e. one that is kept entirely in memory. // Check that we're not overriding an existing file with a synthetic one. - _, err = fs.stepLocked(ctx, rp, parent, true, ds) + _, _, err = fs.stepLocked(ctx, rp, parent, true, ds) switch { case err == nil: // Step succeeded, another file exists. @@ -855,6 +779,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kgid: creds.EffectiveKGID, endpoint: opts.Endpoint, }) + *ds = appendDentry(*ds, parent) return nil case linux.S_IFIFO: parent.createSyntheticChildLocked(&createSyntheticOpts{ @@ -864,6 +789,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kgid: creds.EffectiveKGID, pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) + *ds = appendDentry(*ds, parent) return nil } // Retain error from gofer if synthetic file cannot be created internally. 
@@ -895,12 +821,6 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf defer unlock() start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by fs.stepLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } if rp.Done() { // Reject attempts to open mount root directory with O_CREAT. if mayCreate && rp.MustBeDir() { @@ -909,9 +829,17 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } + if !start.cachedMetadataAuthoritative() { + // Refresh dentry's attributes before opening. + if err := start.updateFromGetattr(ctx); err != nil { + return nil, err + } + } start.IncRef() defer start.DecRef(ctx) unlock() + // start is intentionally not added to ds (which would remove it from the + // cache) because doing so regresses performance in practice. return start.open(ctx, rp, &opts) } @@ -928,9 +856,12 @@ afterTrailingSymlink: if mayCreate && rp.MustBeDir() { return nil, syserror.EISDIR } + if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, rp.Component(), &ds); err != nil { + return nil, err + } // Determine whether or not we need to create a file. parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) if err == syserror.ENOENT && mayCreate { if parent.isSynthetic() { parent.dirMu.Unlock() @@ -965,6 +896,8 @@ afterTrailingSymlink: child.IncRef() defer child.DecRef(ctx) unlock() + // child is intentionally not added to ds (which would remove it from the + // cache) because doing so regresses performance in practice. 
return child.open(ctx, rp, &opts) } @@ -1188,7 +1121,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } - *ds = appendDentry(*ds, child) // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { @@ -1212,6 +1144,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } // Insert the dentry into the tree. d.cacheNewChildLocked(child, name) + appendNewChildDentry(ds, d, child) if d.cachedMetadataAuthoritative() { d.touchCMtime() d.dirents = nil @@ -1296,18 +1229,23 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { return err } + vfsObj := rp.VirtualFilesystem() + if err := fs.revalidateOne(ctx, vfsObj, newParent, newName, &ds); err != nil { + return err + } + if err := fs.revalidateOne(ctx, vfsObj, oldParent, oldName, &ds); err != nil { + return err + } + // We need a dentry representing the renamed file since, if it's a // directory, we need to check for write permission on it. 
oldParent.dirMu.Lock() defer oldParent.dirMu.Unlock() - renamed, err := fs.getChildLocked(ctx, vfsObj, oldParent, oldName, &ds) + renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) if err != nil { return err } - if renamed == nil { - return syserror.ENOENT - } if err := oldParent.mayDelete(creds, renamed); err != nil { return err } @@ -1336,8 +1274,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.isDeleted() { return syserror.ENOENT } - replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds) - if err != nil { + replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil && err != syserror.ENOENT { return err } var replacedVFSD *vfs.Dentry @@ -1401,8 +1339,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // parent isn't actually changing. if oldParent != newParent { oldParent.decRefNoCaching() - ds = appendDentry(ds, oldParent) newParent.IncRef() + ds = appendDentry(ds, newParent) + ds = appendDentry(ds, oldParent) if renamed.isSynthetic() { oldParent.syntheticChildren-- newParent.syntheticChildren++ @@ -1546,6 +1485,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath if d.isSocket() { if !d.isSynthetic() { d.IncRef() + ds = appendDentry(ds, d) return &endpoint{ dentry: d, path: opts.Addr, diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index a0c05231a..21692d2ac 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -18,21 +18,23 @@ // Lock order: // regularFileFD/directoryFD.mu // filesystem.renameMu -// dentry.dirMu -// filesystem.syncMu -// dentry.metadataMu -// *** "memmap.Mappable locks" below this point -// dentry.mapsMu -// *** "memmap.Mappable locks taken by Translate" below this point -// dentry.handleMu -// dentry.dataMu -// filesystem.inoMu +// dentry.cachingMu +// filesystem.cacheMu +// dentry.dirMu +// 
filesystem.syncMu +// dentry.metadataMu +// *** "memmap.Mappable locks" below this point +// dentry.mapsMu +// *** "memmap.Mappable locks taken by Translate" below this point +// dentry.handleMu +// dentry.dataMu +// filesystem.inoMu // specialFileFD.mu // specialFileFD.bufMu // -// Locking dentry.dirMu in multiple dentries requires that either ancestor -// dentries are locked before descendant dentries, or that filesystem.renameMu -// is locked for writing. +// Locking dentry.dirMu and dentry.metadataMu in multiple dentries requires that +// either ancestor dentries are locked before descendant dentries, or that +// filesystem.renameMu is locked for writing. package gofer import ( @@ -140,7 +142,8 @@ type filesystem struct { // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) // cachedDentriesLen is the number of dentries in cachedDentries. These fields - // are protected by renameMu. + // are protected by cacheMu. + cacheMu sync.Mutex `state:"nosave"` cachedDentries dentryList cachedDentriesLen uint64 @@ -620,11 +623,11 @@ func (fs *filesystem) Release(ctx context.Context) { // the reference count on every synthetic dentry. Synthetic dentries have one // reference for existence that should be dropped during filesystem.Release. // -// Precondition: d.fs.renameMu is locked. +// Precondition: d.fs.renameMu is locked for writing. func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { if d.isSynthetic() { d.decRefNoCaching() - d.checkCachingLocked(ctx) + d.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } if d.isDir() { var children []*dentry @@ -682,9 +685,13 @@ type dentry struct { // deleted. deleted is accessed using atomic memory operations. deleted uint32 + // cachingMu is used to synchronize concurrent dentry caching attempts on + // this dentry. 
+ cachingMu sync.Mutex `state:"nosave"` + // If cached is true, dentryEntry links dentry into // filesystem.cachedDentries. cached and dentryEntry are protected by - // filesystem.renameMu. + // cachingMu. cached bool dentryEntry @@ -980,36 +987,63 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } // Preconditions: !d.isSynthetic(). +// Preconditions: d.metadataMu is locked. +func (d *dentry) refreshSizeLocked(ctx context.Context) error { + d.handleMu.RLock() + + if d.writeFD < 0 { + d.handleMu.RUnlock() + // Ask the gofer if we don't have a host FD. + return d.updateFromGetattrLocked(ctx) + } + + var stat unix.Statx_t + err := unix.Statx(int(d.writeFD), "", unix.AT_EMPTY_PATH, unix.STATX_SIZE, &stat) + d.handleMu.RUnlock() // must be released before updateSizeLocked() + if err != nil { + return err + } + d.updateSizeLocked(stat.Size) + return nil +} + +// Preconditions: !d.isSynthetic(). func (d *dentry) updateFromGetattr(ctx context.Context) error { - // Use d.readFile or d.writeFile, which represent 9P fids that have been + // d.metadataMu must be locked *before* we getAttr so that we do not end up + // updating stale attributes in d.updateFromP9AttrsLocked(). + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + return d.updateFromGetattrLocked(ctx) +} + +// Preconditions: +// * !d.isSynthetic(). +// * d.metadataMu is locked. +func (d *dentry) updateFromGetattrLocked(ctx context.Context) error { + // Use d.readFile or d.writeFile, which represent 9P FIDs that have been // opened, in preference to d.file, which represents a 9P fid that has not. // This may be significantly more efficient in some implementations. Prefer // d.writeFile over d.readFile since some filesystem implementations may // update a writable handle's metadata after writes to that handle, without // making metadata updates immediately visible to read-only handles // representing the same file. 
- var ( - file p9file - handleMuRLocked bool - ) - // d.metadataMu must be locked *before* we getAttr so that we do not end up - // updating stale attributes in d.updateFromP9AttrsLocked(). - d.metadataMu.Lock() - defer d.metadataMu.Unlock() d.handleMu.RLock() - if !d.writeFile.isNil() { + handleMuRLocked := true + var file p9file + switch { + case !d.writeFile.isNil(): file = d.writeFile - handleMuRLocked = true - } else if !d.readFile.isNil() { + case !d.readFile.isNil(): file = d.readFile - handleMuRLocked = true - } else { + default: file = d.file d.handleMu.RUnlock() + handleMuRLocked = false } + _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask()) if handleMuRLocked { - d.handleMu.RUnlock() + d.handleMu.RUnlock() // must be released before updateFromP9AttrsLocked() } if err != nil { return err @@ -1104,24 +1138,27 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs defer d.metadataMu.Unlock() // As with Linux, if the UID, GID, or file size is changing, we have to - // clear permission bits. Note that when set, clearSGID causes - // permissions to be updated, but does not modify stat.Mask, as - // modification would cause an extra inotify flag to be set. - clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) || - stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) || + // clear permission bits. Note that when set, clearSGID may cause + // permissions to be updated. 
+ clearSGID := (stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid)) || + (stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid)) || stat.Mask&linux.STATX_SIZE != 0 if clearSGID { if stat.Mask&linux.STATX_MODE != 0 { stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode))) } else { - stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode))) + oldMode := atomic.LoadUint32(&d.mode) + if updatedMode := vfs.ClearSUIDAndSGID(oldMode); updatedMode != oldMode { + stat.Mode = uint16(updatedMode) + stat.Mask |= linux.STATX_MODE + } } } if !d.isSynthetic() { if stat.Mask != 0 { if err := d.file.setAttr(ctx, p9.SetAttrMask{ - Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID, + Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, GID: stat.Mask&linux.STATX_GID != 0, Size: stat.Mask&linux.STATX_SIZE != 0, @@ -1156,7 +1193,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } } - if stat.Mask&linux.STATX_MODE != 0 || clearSGID { + if stat.Mask&linux.STATX_MODE != 0 { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } if stat.Mask&linux.STATX_UID != 0 { @@ -1312,9 +1349,7 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { if d.decRefNoCaching() == 0 { - d.fs.renameMu.Lock() - d.checkCachingLocked(ctx) - d.fs.renameMu.Unlock() + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } } @@ -1374,15 +1409,16 @@ func (d *dentry) Watches() *vfs.Watches { // // If no watches are left on this dentry and it has no references, cache it. 
func (d *dentry) OnZeroWatches(ctx context.Context) { - if atomic.LoadInt64(&d.refs) == 0 { - d.fs.renameMu.Lock() - d.checkCachingLocked(ctx) - d.fs.renameMu.Unlock() - } + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } -// checkCachingLocked should be called after d's reference count becomes 0 or it -// becomes disowned. +// checkCachingLocked should be called after d's reference count becomes 0 or +// it becomes disowned. +// +// For performance, checkCachingLocked can also be called after d's reference +// count becomes non-zero, so that d can be removed from the LRU cache. This +// may help in reducing the size of the cache and hence reduce evictions. Note +// that this is not necessary for correctness. // // It may be called on a destroyed dentry. For example, // renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times @@ -1390,33 +1426,46 @@ func (d *dentry) OnZeroWatches(ctx context.Context) { // operation. One of the calls may destroy the dentry, so subsequent calls will // do nothing. // -// Preconditions: d.fs.renameMu must be locked for writing; it may be -// temporarily unlocked. -func (d *dentry) checkCachingLocked(ctx context.Context) { - // Dentries with a non-zero reference count must be retained. (The only way - // to obtain a reference on a dentry with zero references is via path - // resolution, which requires renameMu, so if d.refs is zero then it will - // remain zero while we hold renameMu for writing.) +// Preconditions: d.fs.renameMu must be locked for writing if +// renameMuWriteLocked is true; it may be temporarily unlocked. +func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked bool) { + d.cachingMu.Lock() refs := atomic.LoadInt64(&d.refs) if refs == -1 { // Dentry has already been destroyed. 
+ d.cachingMu.Unlock() return } if refs > 0 { - // This isn't strictly necessary (fs.cachedDentries is permitted to - // contain dentries with non-zero refs, which are skipped by - // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but - // since we are already holding fs.renameMu for writing we may as well. + // fs.cachedDentries is permitted to contain dentries with non-zero refs, + // which are skipped by fs.evictCachedDentryLocked() upon reaching the end + // of the LRU. But it is still beneficial to remove d from the cache as we + // are already holding d.cachingMu. Keeping a cleaner cache also reduces + // the number of evictions (which is expensive as it acquires fs.renameMu). d.removeFromCacheLocked() + d.cachingMu.Unlock() return } // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { + d.removeFromCacheLocked() + d.cachingMu.Unlock() + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu for writing as needed by d.destroyLocked(). + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + // Now that renameMu is locked for writing, no more refs can be taken on + // d because path resolution requires renameMu for reading at least. + if atomic.LoadInt64(&d.refs) != 0 { + // Destroy d only if its ref is still 0. If not, either someone took a + // ref on it or it got destroyed before fs.renameMu could be acquired. + return + } + } if d.isDeleted() { d.watches.HandleDeletion(ctx) } - d.removeFromCacheLocked() d.destroyLocked(ctx) return } @@ -1426,24 +1475,36 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { // d.watches cannot concurrently transition from zero to non-zero, because // adding a watch requires holding a reference on d. if d.watches.Size() > 0 { - // As in the refs > 0 case, this is not strictly necessary. + // As in the refs > 0 case, removing d is beneficial. 
d.removeFromCacheLocked() + d.cachingMu.Unlock() return } if atomic.LoadInt32(&d.fs.released) != 0 { + d.cachingMu.Unlock() + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu to access d.parent. Lock it for writing as + // needed by d.destroyLocked() later. + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + } if d.parent != nil { d.parent.dirMu.Lock() delete(d.parent.children, d.name) d.parent.dirMu.Unlock() } d.destroyLocked(ctx) + return } + d.fs.cacheMu.Lock() // If d is already cached, just move it to the front of the LRU. if d.cached { d.fs.cachedDentries.Remove(d) d.fs.cachedDentries.PushFront(d) + d.fs.cacheMu.Unlock() + d.cachingMu.Unlock() return } // Cache the dentry, then evict the least recently used cached dentry if @@ -1451,18 +1512,28 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { d.fs.cachedDentries.PushFront(d) d.fs.cachedDentriesLen++ d.cached = true - if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { + shouldEvict := d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries + d.fs.cacheMu.Unlock() + d.cachingMu.Unlock() + + if shouldEvict { + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu for writing as needed by + // d.evictCachedDentryLocked(). + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + } d.fs.evictCachedDentryLocked(ctx) - // Whether or not victim was destroyed, we brought fs.cachedDentriesLen - // back down to fs.opts.maxCachedDentries, so we don't loop. } } -// Preconditions: d.fs.renameMu must be locked for writing. +// Preconditions: d.cachingMu must be locked. func (d *dentry) removeFromCacheLocked() { if d.cached { + d.fs.cacheMu.Lock() d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- + d.fs.cacheMu.Unlock() d.cached = false } } @@ -1477,28 +1548,43 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { // Preconditions: // * fs.renameMu must be locked for writing; it may be temporarily unlocked. -// * fs.cachedDentriesLen != 0. 
func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + fs.cacheMu.Lock() victim := fs.cachedDentries.Back() + fs.cacheMu.Unlock() + if victim == nil { + // fs.cachedDentries may have become empty between when it was checked and + // when we locked fs.cacheMu. + return + } + + victim.cachingMu.Lock() victim.removeFromCacheLocked() // victim.refs or victim.watches.Size() may have become non-zero from an // earlier path resolution since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 { - if victim.parent != nil { - victim.parent.dirMu.Lock() - if !victim.vfsd.IsDead() { - // Note that victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount points. - fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) - delete(victim.parent.children, victim.name) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victim.parent.dirMu.Unlock() + if atomic.LoadInt64(&victim.refs) != 0 || victim.watches.Size() != 0 { + victim.cachingMu.Unlock() + return + } + if victim.parent != nil { + victim.parent.dirMu.Lock() + if !victim.vfsd.IsDead() { + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + // We're only deleting the dentry, not the file it + // represents, so we don't need to update + // victimParent.dirents etc. } - victim.destroyLocked(ctx) + victim.parent.dirMu.Unlock() } + // Safe to unlock cachingMu now that victim.vfsd.IsDead(). Henceforth any + // concurrent caching attempts on victim will attempt to destroy it and so + // will try to acquire fs.renameMu (which we have already acquired). Hence, + // fs.renameMu will synchronize the destroy attempts. 
+ victim.cachingMu.Unlock() + victim.destroyLocked(ctx) } // destroyLocked destroys the dentry. @@ -1584,7 +1670,7 @@ func (d *dentry) destroyLocked(ctx context.Context) { // Drop the reference held by d on its parent without recursively locking // d.fs.renameMu. if d.parent != nil && d.parent.decRefNoCaching() == 0 { - d.parent.checkCachingLocked(ctx) + d.parent.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } refsvfs2.Unregister(d) } diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index 76f08e252..806392d50 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -55,7 +55,7 @@ func TestDestroyIdempotent(t *testing.T) { fs.renameMu.Lock() defer fs.renameMu.Unlock() - child.checkCachingLocked(ctx) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) if got := atomic.LoadInt64(&child.refs); got != -1 { t.Fatalf("child.refs=%d, want: -1", got) } @@ -63,6 +63,6 @@ func TestDestroyIdempotent(t *testing.T) { if got := atomic.LoadInt64(&parent.refs); got != -1 { t.Fatalf("parent.refs=%d, want: -1", got) } - child.checkCachingLocked(ctx) - child.checkCachingLocked(ctx) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 21b4a96fe..b0a429d42 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -238,3 +238,10 @@ func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, err ctx.UninterruptibleSleepFinish(false) return fdobj, err } + +func (f p9file) multiGetAttr(ctx context.Context, names []string) ([]p9.FullStat, error) { + ctx.UninterruptibleSleepStart(false) + stats, err := f.file.MultiGetAttr(names) + ctx.UninterruptibleSleepFinish(false) + return stats, err +} diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go 
b/pkg/sentry/fsimpl/gofer/regular_file.go index 47563538c..f0e7bbaf7 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -204,18 +204,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } d := fd.dentry() + + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + // If the fd was opened with O_APPEND, make sure the file size is updated. // There is a possible race here if size is modified externally after // metadata cache is updated. if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { - if err := d.updateFromGetattr(ctx); err != nil { + if err := d.refreshSizeLocked(ctx); err != nil { return 0, offset, err } } - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - // Set offset to file size if the fd was opened with O_APPEND. if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { // Holding d.metadataMu is sufficient for reading d.size. @@ -701,6 +702,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt } // After this point, d may be used as a memmap.Mappable. d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init) + opts.SentryOwnedContent = d.fs.opts.forcePageCache return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts) } diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go new file mode 100644 index 000000000..8f81f0822 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/revalidate.go @@ -0,0 +1,386 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +type errPartialRevalidation struct{} + +// Error implements error.Error. +func (errPartialRevalidation) Error() string { + return "partial revalidation" +} + +type errRevalidationStepDone struct{} + +// Error implements error.Error. +func (errRevalidationStepDone) Error() string { + return "stop revalidation" +} + +// revalidatePath checks cached dentries for external modification. File +// attributes are refreshed and cache is invalidated in case the dentry has been +// deleted, or a new file/directory created in its place. +// +// Revalidation stops at symlinks and mount points. The caller is responsible +// for revalidating again after symlinks are resolved and after changing to +// different mounts. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidatePath(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error { + // Revalidation is done even if start is synthetic in case the path is + // something like: ../non_synthetic_file. + if fs.opts.interop != InteropModeShared { + return nil + } + + // Copy resolving path to walk the path for revalidation. + rp := rpOrig.Copy() + err := fs.revalidate(ctx, rp, start, rp.Done, ds) + rp.Release(ctx) + return err +} + +// revalidateParentDir does the same as revalidatePath, but stops at the parent. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidateParentDir(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error { + // Revalidation is done even if start is synthetic in case the path is + // something like: ../non_synthetic_file and parent is non synthetic. 
+ if fs.opts.interop != InteropModeShared { + return nil + } + + // Copy resolving path to walk the path for revalidation. + rp := rpOrig.Copy() + err := fs.revalidate(ctx, rp, start, rp.Final, ds) + rp.Release(ctx) + return err +} + +// revalidateOne does the same as revalidatePath, but checks a single dentry. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidateOne(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) error { + // Skip revalidation for interop mode different than InteropModeShared or + // if the parent is synthetic (child must be synthetic too, but it cannot be + // replaced without first replacing the parent). + if parent.cachedMetadataAuthoritative() { + return nil + } + + parent.dirMu.Lock() + child, ok := parent.children[name] + parent.dirMu.Unlock() + if !ok { + return nil + } + + state := makeRevalidateState(parent) + defer state.release() + + state.add(name, child) + return fs.revalidateHelper(ctx, vfsObj, state, ds) +} + +// revalidate revalidates path components in rp until done returns true, or +// until a mount point or symlink is reached. It may send multiple MultiGetAttr +// calls to the gofer to handle ".." in the path. +// +// Preconditions: +// * fs.renameMu must be locked. +// * InteropModeShared is in effect. +func (fs *filesystem) revalidate(ctx context.Context, rp *vfs.ResolvingPath, start *dentry, done func() bool, ds **[]*dentry) error { + state := makeRevalidateState(start) + defer state.release() + + // Skip synthetic dentries because the start dentry cannot be replaced in case + // it has been created in the remote file system. 
+ if !start.isSynthetic() { + state.add("", start) + } + +done: + for cur := start; !done(); { + var err error + cur, err = fs.revalidateStep(ctx, rp, cur, state) + if err != nil { + switch err.(type) { + case errPartialRevalidation: + if err := fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds); err != nil { + return err + } + + // Reset state to release any remaining locks and restart from where + // stepping stopped. + state.reset() + state.start = cur + + // Skip synthetic dentries because the start dentry cannot be replaced in + // case it has been created in the remote file system. + if !cur.isSynthetic() { + state.add("", cur) + } + + case errRevalidationStepDone: + break done + + default: + return err + } + } + } + return fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds) +} + +// revalidateStep walks one element of the path and updates revalidationState +// with the dentry if needed. It may also stop the stepping or ask for a +// partial revalidation. Partial revalidation requires the caller to revalidate +// the current revalidationState, release all locks, and resume stepping. +// In case a symlink is hit, revalidation stops and the caller is responsible +// for calling revalidate again after the symlink is resolved. Revalidation may +// also stop for other reasons, like hitting a child not in the cache. +// +// Returns: +// * (dentry, nil): step worked, continue stepping.` +// * (dentry, errPartialRevalidation): revalidation should be done with the +// state gathered so far. Then continue stepping with the remainder of the +// path, starting at `dentry`. +// * (nil, errRevalidationStepDone): revalidation doesn't need to step any +// further. It hit a symlink, a mount point, or an uncached dentry. +// +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). +// * InteropModeShared is in effect (assumes no negative dentries). 
+func (fs *filesystem) revalidateStep(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, state *revalidateState) (*dentry, error) { + switch name := rp.Component(); name { + case ".": + // Do nothing. + + case "..": + // Partial revalidation is required when ".." is hit because metadata locks + // can only be acquired from parent to child to avoid deadlocks. + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { + return nil, errRevalidationStepDone{} + } else if isRoot || d.parent == nil { + rp.Advance() + return d, errPartialRevalidation{} + } + // We must assume that d.parent is correct, because if d has been moved + // elsewhere in the remote filesystem so that its parent has changed, + // we have no way of determining its new parent's location in the + // filesystem. + // + // Call rp.CheckMount() before updating d.parent's metadata, since if + // we traverse to another mount then d.parent's metadata is irrelevant. + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + return nil, errRevalidationStepDone{} + } + rp.Advance() + return d.parent, errPartialRevalidation{} + + default: + d.dirMu.Lock() + child, ok := d.children[name] + d.dirMu.Unlock() + if !ok { + // child is not cached, no need to validate any further. + return nil, errRevalidationStepDone{} + } + + state.add(name, child) + + // Symlink must be resolved before continuing with revalidation. + if child.isSymlink() { + return nil, errRevalidationStepDone{} + } + + d = child + } + + rp.Advance() + return d, nil +} + +// revalidateHelper calls the gofer to stat all dentries in `state`. It will +// update or invalidate dentries in the cache based on the result. +// +// Preconditions: +// * fs.renameMu must be locked. +// * InteropModeShared is in effect. 
+func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualFilesystem, state *revalidateState, ds **[]*dentry) error { + if len(state.names) == 0 { + return nil + } + // Lock metadata on all dentries *before* getting attributes for them. + state.lockAllMetadata() + stats, err := state.start.file.multiGetAttr(ctx, state.names) + if err != nil { + return err + } + + i := -1 + for d := state.popFront(); d != nil; d = state.popFront() { + i++ + found := i < len(stats) + if i == 0 && len(state.names[0]) == 0 { + if found && !d.isSynthetic() { + // First dentry is where the search is starting, just update attributes + // since it cannot be replaced. + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) + } + d.metadataMu.Unlock() + continue + } + + // Note that synthetic dentries will always fails the comparison check + // below. + if !found || d.qidPath != stats[i].QID.Path { + d.metadataMu.Unlock() + if !found && d.isSynthetic() { + // We have a synthetic file, and no remote file has arisen to replace + // it. + return nil + } + // The file at this path has changed or no longer exists. Mark the + // dentry invalidated, and re-evaluate its caching status (i.e. if it + // has 0 references, drop it). The dentry will be reloaded next time it's + // accessed. + vfsObj.InvalidateDentry(ctx, &d.vfsd) + + name := state.names[i] + d.parent.dirMu.Lock() + + if d.isSynthetic() { + // Normally we don't mark invalidated dentries as deleted since + // they may still exist (but at a different path), and also for + // consistency with Linux. However, synthetic files are guaranteed + // to become unreachable if their dentries are invalidated, so + // treat their invalidation as deletion. + d.setDeleted() + d.decRefNoCaching() + *ds = appendDentry(*ds, d) + + d.parent.syntheticChildren-- + d.parent.dirents = nil + } + + // Since the dirMu was released and reacquired, re-check that the + // parent's child with this name is still the same. 
Do not touch it if + // it has been replaced with a different one. + if child := d.parent.children[name]; child == d { + // Invalidate dentry so it gets reloaded next time it's accessed. + delete(d.parent.children, name) + } + d.parent.dirMu.Unlock() + + return nil + } + + // The file at this path hasn't changed. Just update cached metadata. + d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr) + d.metadataMu.Unlock() + } + + return nil +} + +// revalidateStatePool caches revalidateState instances to save array +// allocations for dentries and names. +var revalidateStatePool = sync.Pool{ + New: func() interface{} { + return &revalidateState{} + }, +} + +// revalidateState keeps state related to a revalidation request. It keeps track +// of {name, dentry} list being revalidated, as well as metadata locks on the +// dentries. The list must be in ancestry order, in other words `n` must be +// `n-1` child. +type revalidateState struct { + // start is the dentry where to start the attributes search. + start *dentry + + // List of names of entries to refresh attributes. Names length must be the + // same as detries length. They are kept in separate slices because names is + // used to call File.MultiGetAttr(). + names []string + + // dentries is the list of dentries that correspond to the names above. + // dentry.metadataMu is acquired as each dentry is added to this list. + dentries []*dentry + + // locked indicates if metadata lock has been acquired on dentries. + locked bool +} + +func makeRevalidateState(start *dentry) *revalidateState { + r := revalidateStatePool.Get().(*revalidateState) + r.start = start + return r +} + +// release must be called after the caller is done with this object. It releases +// all metadata locks and resources. +func (r *revalidateState) release() { + r.reset() + revalidateStatePool.Put(r) +} + +// Preconditions: +// * d is a descendant of all dentries in r.dentries. 
+func (r *revalidateState) add(name string, d *dentry) { + r.names = append(r.names, name) + r.dentries = append(r.dentries, d) +} + +func (r *revalidateState) lockAllMetadata() { + for _, d := range r.dentries { + d.metadataMu.Lock() + } + r.locked = true +} + +func (r *revalidateState) popFront() *dentry { + if len(r.dentries) == 0 { + return nil + } + d := r.dentries[0] + r.dentries = r.dentries[1:] + return d +} + +// reset releases all metadata locks and resets all fields to allow this +// instance to be reused. +func (r *revalidateState) reset() { + if r.locked { + // Unlock any remaining dentries. + for _, d := range r.dentries { + d.metadataMu.Unlock() + } + r.locked = false + } + r.start = nil + r.names = r.names[:0] + r.dentries = r.dentries[:0] +} diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 3b90375b6..a81f550b1 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -460,6 +460,9 @@ func (i *inode) DecRef(ctx context.Context) { if err := unix.Close(i.hostFD); err != nil { log.Warningf("failed to close host fd %d: %v", i.hostFD, err) } + // We can't rely on fdnotifier when closing the fd, because the event may race + // with fdnotifier.RemoveFD. Instead, notify the queue explicitly. + i.queue.Notify(waiter.EventHUp | waiter.ReadableEvents | waiter.WritableEvents) }) } diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go index 31301c715..c502d8e99 100644 --- a/pkg/sentry/fsimpl/host/save_restore.go +++ b/pkg/sentry/fsimpl/host/save_restore.go @@ -68,3 +68,10 @@ func (i *inode) afterLoad() { } } } + +// afterLoad is invoked by stateify. 
+func (c *ConnectedEndpoint) afterLoad() { + if err := c.initFromOptions(); err != nil { + panic(fmt.Sprintf("initFromOptions failed: %v", err)) + } +} diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 60e237ac7..ca85f5601 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -39,7 +39,7 @@ import ( func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) { // Set up an external transport.Endpoint using the host fd. addr := fmt.Sprintf("hostfd:[%d]", hostFD) - e, err := NewConnectedEndpoint(ctx, hostFD, addr, true /* saveable */) + e, err := NewConnectedEndpoint(hostFD, addr) if err != nil { return nil, err.ToError() } @@ -86,7 +86,10 @@ type ConnectedEndpoint struct { // for restoring them. func (c *ConnectedEndpoint) init() *syserr.Error { c.InitRefs() + return c.initFromOptions() +} +func (c *ConnectedEndpoint) initFromOptions() *syserr.Error { family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN) if err != nil { return syserr.FromError(err) @@ -123,7 +126,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { // The caller is responsible for calling Init(). Additionaly, Release needs to // be called twice because ConnectedEndpoint is both a transport.Receiver and // transport.ConnectedEndpoint. -func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) { +func NewConnectedEndpoint(hostFD int, addr string) (*ConnectedEndpoint, *syserr.Error) { e := ConnectedEndpoint{ fd: hostFD, addr: addr, @@ -330,8 +333,16 @@ func (c *ConnectedEndpoint) CloseUnread() {} // SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize. func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) { - // gVisor does not permit setting of SO_SNDBUF for host backed unix domain - // sockets. 
+ // gVisor does not permit setting of SO_SNDBUF for host backed unix + // domain sockets. + return atomic.LoadInt64(&c.sndbuf) +} + +// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize. +func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) { + // gVisor does not permit setting of SO_RCVBUF for host backed unix + // domain sockets. Receive buffer does not have any effect for unix + // sockets and we claim to be the same as send buffer. return atomic.LoadInt64(&c.sndbuf) } diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 65054b0ea..84b1c3745 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -25,8 +25,10 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// DynamicBytesFile implements kernfs.Inode and represents a read-only -// file whose contents are backed by a vfs.DynamicBytesSource. +// DynamicBytesFile implements kernfs.Inode and represents a read-only file +// whose contents are backed by a vfs.DynamicBytesSource. If data additionally +// implements vfs.WritableDynamicBytesSource, the file also supports dispatching +// writes to the implementer, but note that this will not update the source data. // // Must be instantiated with NewDynamicBytesFile or initialized with Init // before first use. @@ -40,7 +42,9 @@ type DynamicBytesFile struct { InodeNotSymlink locks vfs.FileLocks - data vfs.DynamicBytesSource + // data can additionally implement vfs.WritableDynamicBytesSource to support + // writes. + data vfs.DynamicBytesSource } var _ Inode = (*DynamicBytesFile)(nil) diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index badca4d9f..f50b0fb08 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -612,16 +612,24 @@ afterTrailingSymlink: // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. 
func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - fs.mu.RLock() defer fs.processDeferredDecRefs(ctx) - defer fs.mu.RUnlock() + + fs.mu.RLock() d, err := fs.walkExistingLocked(ctx, rp) if err != nil { + fs.mu.RUnlock() return "", err } if !d.isSymlink() { + fs.mu.RUnlock() return "", syserror.EINVAL } + + // Inode.Readlink() cannot be called holding fs locks. + d.IncRef() + defer d.DecRef(ctx) + fs.mu.RUnlock() + return d.inode.Readlink(ctx, rp.Mount()) } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 565d723f0..6f699c9cd 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -61,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -508,6 +509,15 @@ func (d *Dentry) Inode() Inode { return d.inode } +// FSLocalPath returns an absolute path to d, relative to the root of its +// filesystem. +func (d *Dentry) FSLocalPath() string { + var b fspath.Builder + _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b) + b.PrependByte('/') + return b.String() +} + // The Inode interface maps filesystem-level operations that operate on paths to // equivalent operations on specific filesystem nodes. // @@ -524,6 +534,9 @@ func (d *Dentry) Inode() Inode { // - Checking that dentries passed to methods are of the appropriate file type. // - Checking permissions. // +// Inode functions may be called holding filesystem wide locks and are not +// allowed to call vfs functions that may reenter, unless otherwise noted. +// // Specific responsibilities of implementations are documented below. type Inode interface { // Methods related to reference counting. 
A generic implementation is @@ -670,6 +683,9 @@ type inodeDirectory interface { type inodeSymlink interface { // Readlink returns the target of a symbolic link. If an inode is not a // symlink, the implementation should return EINVAL. + // + // Readlink is called with no kernfs locks held, so it may reenter if needed + // to resolve symlink targets. Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) // Getlink returns the target of a symbolic link, as used by path diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 254a8b062..ce8f55b1f 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -86,13 +86,13 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF procfs.MaxCachedDentries = maxCachedDentries procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) - var cgroups map[string]string + var fakeCgroupControllers map[string]string if opts.InternalData != nil { data := opts.InternalData.(*InternalData) - cgroups = data.Cgroups + fakeCgroupControllers = data.Cgroups } - inode := procfs.newTasksInode(ctx, k, pidns, cgroups) + inode := procfs.newTasksInode(ctx, k, pidns, fakeCgroupControllers) var dentry kernfs.Dentry dentry.InitRoot(&procfs.Filesystem, inode) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index fea138f93..d05cc1508 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -47,7 +47,7 @@ type taskInode struct { var _ kernfs.Inode = (*taskInode)(nil) -func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) { +func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, fakeCgroupControllers map[string]string) (kernfs.Inode, error) { if 
task.ExitState() == kernel.TaskExitDead { return nil, syserror.ESRCH } @@ -82,10 +82,12 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns "uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { - contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers) + contents["task"] = fs.newSubtasks(ctx, task, pidns, fakeCgroupControllers) } - if len(cgroupControllers) > 0 { - contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers)) + if len(fakeCgroupControllers) > 0 { + contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newFakeCgroupData(fakeCgroupControllers)) + } else { + contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskCgroupData{task: task}) } taskInode := &taskInode{task: task} @@ -226,11 +228,14 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData { return &ioData{ioUsage: t} } -// newCgroupData creates inode that shows cgroup information. -// From man 7 cgroups: "For each cgroup hierarchy of which the process is a -// member, there is one entry containing three colon-separated fields: -// hierarchy-ID:controller-list:cgroup-path" -func newCgroupData(controllers map[string]string) dynamicInode { +// newFakeCgroupData creates an inode that shows fake cgroup +// information passed in as mount options. From man 7 cgroups: "For +// each cgroup hierarchy of which the process is a member, there is +// one entry containing three colon-separated fields: +// hierarchy-ID:controller-list:cgroup-path" +// +// TODO(b/182488796): Remove once all users adopt cgroupfs. 
+func newFakeCgroupData(controllers map[string]string) dynamicInode { var buf bytes.Buffer // The hierarchy ids must be positive integers (for cgroup v1), but the diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 02bf74dbc..4718fac7a 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -221,6 +221,8 @@ func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) defer file.DecRef(ctx) root := vfs.RootFromContext(ctx) defer root.DecRef(ctx) + + // Note: it's safe to reenter kernfs from Readlink if needed to resolve path. return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry()) } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 85909d551..b294dfd6a 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -1100,3 +1100,32 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err func (fd *namespaceFD) Release(ctx context.Context) { fd.inode.DecRef(ctx) } + +// taskCgroupData generates data for /proc/[pid]/cgroup. +// +// +stateify savable +type taskCgroupData struct { + dynamicBytesFileSetAttr + task *kernel.Task +} + +var _ dynamicInode = (*taskCgroupData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *taskCgroupData) Generate(ctx context.Context, buf *bytes.Buffer) error { + // When a task is existing on Linux, a task's cgroup set is cleared and + // reset to the initial cgroup set, which is essentially the set of root + // cgroups. Because of this, the /proc/<pid>/cgroup file is always readable + // on Linux throughout a task's lifetime. + // + // The sentry removes tasks from cgroups during the exit process, but + // doesn't move them into an initial cgroup set, so partway through task + // exit this file show a task is in no cgroups, which is incorrect. 
Instead, + // once a task has left its cgroups, we return an error. + if d.task.ExitState() >= kernel.TaskExitInitiated { + return syserror.ESRCH + } + + d.task.GenerateProcTaskCgroup(buf) + return nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index fdc580610..cf905fae4 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -54,17 +54,18 @@ type tasksInode struct { // '/proc/self' and '/proc/thread-self' have custom directory offsets in // Linux. So handle them outside of OrderedChildren. - // cgroupControllers is a map of controller name to directory in the + // fakeCgroupControllers is a map of controller name to directory in the // cgroup hierarchy. These controllers are immutable and will be listed // in /proc/pid/cgroup if not nil. - cgroupControllers map[string]string + fakeCgroupControllers map[string]string } var _ kernfs.Inode = (*tasksInode)(nil) -func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { +func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ + "cmdline": fs.newInode(ctx, root, 0444, &cmdLineData{}), "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), @@ -76,11 +77,16 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}), "version": fs.newInode(ctx, root, 0444, &versionData{}), } + // If fakeCgroupControllers are provided, don't create a cgroupfs backed + // /proc/cgroup as it will not match the fake controllers. 
+ if len(fakeCgroupControllers) == 0 { + contents["cgroups"] = fs.newInode(ctx, root, 0444, &cgroupsData{}) + } inode := &tasksInode{ - pidns: pidns, - fs: fs, - cgroupControllers: cgroupControllers, + pidns: pidns, + fs: fs, + fakeCgroupControllers: fakeCgroupControllers, } inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.InitRefs() @@ -118,7 +124,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err return nil, syserror.ENOENT } - return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers) + return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers) } // IterDirents implements kernfs.inodeDirectory.IterDirents. diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index f0029cda6..045ed7a2d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -336,15 +336,6 @@ var _ dynamicInode = (*versionData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { - k := kernel.KernelFromContext(ctx) - init := k.GlobalInit() - if init == nil { - // Attempted to read before the init Task is created. This can - // only occur during startup, which should never need to read - // this file. - panic("Attempted to read version before initial Task is available") - } - // /proc/version takes the form: // // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST) @@ -364,7 +355,7 @@ func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { // FIXME(mpratt): Using Version from the init task SyscallTable // disregards the different version a task may have (e.g., in a uts // namespace). 
- ver := init.Leader().SyscallTable().Version + ver := kernelVersion(ctx) fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version) return nil } @@ -384,3 +375,47 @@ func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error k.VFS().GenerateProcFilesystems(buf) return nil } + +// cgroupsData backs /proc/cgroups. +// +// +stateify savable +type cgroupsData struct { + dynamicBytesFileSetAttr +} + +var _ dynamicInode = (*cgroupsData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + r := kernel.KernelFromContext(ctx).CgroupRegistry() + r.GenerateProcCgroups(buf) + return nil +} + +// cmdLineData backs /proc/cmdline. +// +// +stateify savable +type cmdLineData struct { + dynamicBytesFileSetAttr +} + +var _ dynamicInode = (*cmdLineData)(nil) + +// Generate implements vfs.DynamicByteSource.Generate. +func (*cmdLineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "BOOT_IMAGE=/vmlinuz-%s-gvisor quiet\n", kernelVersion(ctx).Release) + return nil +} + +// kernelVersion returns the kernel version. +func kernelVersion(ctx context.Context) kernel.Version { + k := kernel.KernelFromContext(ctx) + init := k.GlobalInit() + if init == nil { + // Attempted to read before the init Task is created. This can + // only occur during startup, which should never need to read + // this file. 
+ panic("Attempted to read version before initial Task is available") + } + return init.Leader().SyscallTable().Version +} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index d6f076cd6..e534fbca8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -47,6 +47,7 @@ var ( var ( tasksStaticFiles = map[string]testutil.DirentType{ + "cmdline": linux.DT_REG, "cpuinfo": linux.DT_REG, "filesystems": linux.DT_REG, "loadavg": linux.DT_REG, diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1d9280dae..14eb10dcd 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -122,11 +122,11 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs } func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode { - // If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs - // should be mounted at debug/, but for our purposes, it is sufficient to - // keep it in sys. + // Set up /sys/kernel/debug/kcov. Technically, debugfs should be + // mounted at debug/, but for our purposes, it is sufficient to keep it + // in sys. 
var children map[string]kernfs.Inode - if coverage.KcovAvailable() { + if coverage.KcovSupported() { log.Debugf("Set up /sys/kernel/debug/kcov") children = map[string]kernfs.Inode{ "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{ diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index b3f9d1010..c766164c7 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -17,6 +17,7 @@ go_library( "//pkg/fspath", "//pkg/hostarch", "//pkg/memutil", + "//pkg/metric", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 807e4f44a..33e52ce64 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -62,6 +63,8 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } + metric.CreateSentryMetrics() + kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index cd849e87e..c45bddff6 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -488,6 +488,7 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. 
func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { file := fd.inode().impl.(*regularFile) + opts.SentryOwnedContent = true return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts) } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 2da251233..d473a922d 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -18,10 +18,12 @@ go_library( "//pkg/marshal/primitive", "//pkg/merkletree", "//pkg/refsvfs2", + "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/sync", diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 6cb1a23e0..3582d14c9 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -168,10 +168,6 @@ afterSymlink: // Preconditions: // * fs.renameMu must be locked. // * d.dirMu must be locked. -// -// TODO(b/166474175): Investigate all possible errors returned in this -// function, and make sure we differentiate all errors that indicate unexpected -// modifications to the file system from the ones that are not harmful. func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() @@ -200,7 +196,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. 
if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -209,7 +205,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's hash. @@ -223,12 +219,14 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err } + defer parentMerkleFD.DecRef(ctx) + // dataSize is the size of raw data for the Merkle tree. For a file, // dataSize is the size of the whole file. For a directory, dataSize is // the size of all its children's hashes. @@ -241,7 +239,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. 
if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -251,7 +249,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := FileReadWriteSeeker{ @@ -264,7 +262,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi Start: parent.lowerVD, }, &vfs.StatOptions{}) if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -276,16 +274,15 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi var buf bytes.Buffer parent.hashMu.RLock() _, err = merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, - Children: parent.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. 
+ Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + Children: parent.childrenNames, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), @@ -294,7 +291,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi }) parent.hashMu.RUnlock() if err != nil && err != io.EOF { - return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. @@ -331,19 +328,21 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { return err } + defer fd.DecRef(ctx) + merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ Name: merkleSizeXattr, Size: sizeOfStringInt32, }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return err @@ -351,7 +350,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry size, err := strconv.Atoi(merkleSize) if err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", 
merkleSizeXattr, childPath, err)) } if d.isDir() && len(d.childrenNames) == 0 { @@ -361,14 +360,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) } if err != nil { return err } childrenOffset, err := strconv.Atoi(childrenOffString) if err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) } childrenSizeString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ @@ -377,23 +376,23 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) } if err != nil { return err } childrenSize, err := strconv.Atoi(childrenSizeString) if err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) } childrenNames := make([]byte, childrenSize) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(childrenOffset), vfs.ReadOptions{}); err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err)) } 
if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err)) } } @@ -405,15 +404,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry var buf bytes.Buffer d.hashMu.RLock() params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - Children: d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + Children: d.childrenNames, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: 0, // Set read size to 0 so only the metadata is verified. @@ -438,7 +436,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry } if _, err := merkletree.Verify(params); err != nil && err != io.EOF { - return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) } d.mode = uint32(stat.Mode) d.uid = stat.UID @@ -471,7 +469,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // The file was previously accessed. If the // file does not exist now, it indicates an // unexpected modification to the file system. 
- return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) } if err != nil { return nil, err @@ -483,7 +481,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // does not exist now, it indicates an unexpected // modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) } if err != nil { return nil, err @@ -553,8 +551,8 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } childVD, err := parent.getLowerAt(ctx, vfsObj, name) - if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) + if parent.verityEnabled() && err == syserror.ENOENT { + return nil, fs.alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) } if err != nil { return nil, err @@ -565,30 +563,31 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, defer childVD.DecRef(ctx) childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) - if err == syserror.ENOENT { - if !fs.allowRuntimeEnable { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) - } - childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(merklePrefix + name), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT, - Mode: 0644, - }) - if err != nil { - return nil, err - } - childMerkleFD.DecRef(ctx) - childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) - if err != nil { + if err != 
nil { + if err == syserror.ENOENT { + if parent.verityEnabled() { + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) + } + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(merklePrefix + name), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + if err != nil { + return nil, err + } + } else { return nil, err } } - if err != nil && err != syserror.ENOENT { - return nil, err - } // Clear the Merkle tree file if they are to be generated at runtime. // TODO(b/182315468): Optimize the Merkle tree generate process to @@ -632,8 +631,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childVD.IncRef() childMerkleVD.IncRef() - parent.IncRef() - child.parent = parent child.name = name child.mode = uint32(stat.Mode) @@ -657,6 +654,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } } + parent.IncRef() + child.parent = parent + return child, nil } @@ -855,7 +855,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // missing, it indicates an unexpected modification to the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err } @@ -878,7 +878,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // the file system. 
if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -903,7 +903,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }) if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -921,7 +921,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err != nil { if err == syserror.ENOENT { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } return nil, err } @@ -985,8 +985,6 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts } // StatAt implements vfs.FilesystemImpl.StatAt. -// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should -// be verified. 
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index a7d92a878..31d34ef60 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -34,6 +34,8 @@ package verity import ( + "bytes" + "encoding/hex" "encoding/json" "fmt" "math" @@ -44,19 +46,20 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/merkletree" "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - - "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -95,14 +98,18 @@ const ( ) var ( - // action specifies the action towards detected violation. - action ViolationAction - // verityMu synchronizes concurrent operations that enable verity and perform // verification checks. verityMu sync.RWMutex ) +// Mount option names for verityfs. +const ( + moptLowerPath = "lower_path" + moptRootHash = "root_hash" + moptRootName = "root_name" +) + // HashAlgorithm is a type specifying the algorithm used to hash the file // content. type HashAlgorithm int @@ -169,6 +176,12 @@ type filesystem struct { // system. alg HashAlgorithm + // action specifies the action towards detected violation. + action ViolationAction + + // opts is the string mount options passed to opts.Data. 
+ opts string + // renameMu synchronizes renaming with non-renaming operations in order // to ensure consistent lock ordering between dentry.dirMu in different // dentries. @@ -191,9 +204,6 @@ type filesystem struct { // // +stateify savable type InternalFilesystemOptions struct { - // RootMerkleFileName is the name of the verity root Merkle tree file. - RootMerkleFileName string - // LowerName is the name of the filesystem wrapped by verity fs. LowerName string @@ -201,9 +211,6 @@ type InternalFilesystemOptions struct { // system. Alg HashAlgorithm - // RootHash is the root hash of the overall verity file system. - RootHash []byte - // AllowRuntimeEnable specifies whether the verity file system allows // enabling verification for files (i.e. building Merkle trees) during // runtime. @@ -228,8 +235,8 @@ func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In ErrorOnViolation // mode, it returns EIO, otherwise it panic. -func alertIntegrityViolation(msg string) error { - if action == ErrorOnViolation { +func (fs *filesystem) alertIntegrityViolation(msg string) error { + if fs.action == ErrorOnViolation { return syserror.EIO } panic(msg) @@ -237,28 +244,99 @@ func alertIntegrityViolation(msg string) error { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. 
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + mopts := vfs.GenericParseMountOptions(opts.Data) + var rootHash []byte + if encodedRootHash, ok := mopts[moptRootHash]; ok { + delete(mopts, moptRootHash) + hash, err := hex.DecodeString(encodedRootHash) + if err != nil { + ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err) + return nil, nil, syserror.EINVAL + } + rootHash = hash + } + var lowerPathname string + if path, ok := mopts[moptLowerPath]; ok { + delete(mopts, moptLowerPath) + lowerPathname = path + } + rootName := "root" + if root, ok := mopts[moptRootName]; ok { + delete(mopts, moptRootName) + rootName = root + } + + // Check for unparsed options. + if len(mopts) != 0 { + ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts) + return nil, nil, syserror.EINVAL + } + + // Handle internal options. iopts, ok := opts.InternalData.(InternalFilesystemOptions) - if !ok { + if len(lowerPathname) == 0 && !ok { ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } - action = iopts.Action - - // Mount the lower file system. The lower file system is wrapped inside - // verity, and should not be exposed or connected. 
- mopts := &vfs.MountOptions{ - GetFilesystemOptions: iopts.LowerGetFSOptions, - InternalMount: true, + if len(lowerPathname) != 0 { + if ok { + ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path") + return nil, nil, syserror.EINVAL + } + iopts = InternalFilesystemOptions{ + AllowRuntimeEnable: len(rootHash) == 0, + Action: ErrorOnViolation, + } } - mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts) - if err != nil { - return nil, nil, err + + var lowerMount *vfs.Mount + var mountedLowerVD vfs.VirtualDentry + // Use an existing mount if lowerPath is provided. + if len(lowerPathname) != 0 { + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + lowerPath := fspath.Parse(lowerPathname) + if !lowerPath.Absolute { + ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname) + return nil, nil, syserror.EINVAL + } + var err error + mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: lowerPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("verity.FilesystemType.GetFilesystem: failed to resolve lower_path %q: %v", lowerPathname, err) + return nil, nil, err + } + lowerMount = mountedLowerVD.Mount() + defer mountedLowerVD.DecRef(ctx) + } else { + // Mount the lower file system. The lower file system is wrapped inside + // verity, and should not be exposed or connected. 
+ mountOpts := &vfs.MountOptions{ + GetFilesystemOptions: iopts.LowerGetFSOptions, + InternalMount: true, + } + mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mountOpts) + if err != nil { + return nil, nil, err + } + lowerMount = mnt } fs := &filesystem{ creds: creds.Fork(), alg: iopts.Alg, - lowerMount: mnt, + lowerMount: lowerMount, + action: iopts.Action, + opts: opts.Data, allowRuntimeEnable: iopts.AllowRuntimeEnable, } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -266,11 +344,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Construct the root dentry. d := fs.newDentry() d.refs = 1 - lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root()) + lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root()) lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + rootName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -311,7 +389,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // the root Merkle file, or it's never generated. fs.vfsfs.DecRef(ctx) d.DecRef(ctx) - return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") + return nil, nil, fs.alertIntegrityViolation("Failed to find root Merkle file") } // Clear the Merkle tree file if they are to be generated at runtime. 
@@ -350,9 +428,15 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID - d.hash = make([]byte, len(iopts.RootHash)) d.childrenNames = make(map[string]struct{}) + d.hashMu.Lock() + d.hash = make([]byte, len(rootHash)) + copy(d.hash, rootHash) + d.hashMu.Unlock() + + fs.rootDentry = d + if !d.isDir() { ctx.Warningf("verity root must be a directory") return nil, nil, syserror.EINVAL @@ -368,7 +452,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Size: sizeOfStringInt32, }) if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) } if err != nil { return nil, nil, err @@ -376,7 +460,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt off, err := strconv.Atoi(offString) if err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) } sizeString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{ @@ -387,14 +471,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Size: sizeOfStringInt32, }) if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) } if err != nil { return nil, nil, err } size, err := strconv.Atoi(sizeString) if err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", 
childrenSizeXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) } lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ @@ -404,19 +488,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) } if err != nil { return nil, nil, err } + defer lowerMerkleFD.DecRef(ctx) + childrenNames := make([]byte, size) if _, err := lowerMerkleFD.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(off), vfs.ReadOptions{}); err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err)) } if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) } if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { @@ -424,13 +510,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } - d.hashMu.Lock() - copy(d.hash, iopts.RootHash) - d.hashMu.Unlock() d.vfsd.Init(d) - fs.rootDentry = d - return &fs.vfsfs, &d.vfsd, nil } @@ -441,7 +522,7 @@ func (fs *filesystem) Release(ctx context.Context) { // MountOptions implements vfs.FilesystemImpl.MountOptions. func (fs *filesystem) MountOptions() string { - return "" + return fs.opts } // dentry implements vfs.DentryImpl. @@ -722,6 +803,10 @@ type fileDescription struct { // underlying file system. 
lowerFD *vfs.FileDescription + // lowerMappable is the memmap.Mappable corresponding to this file in the + // underlying file system. + lowerMappable memmap.Mappable + // merkleReader is the read-only FileDescription corresponding to the // Merkle tree file in the underlying file system. merkleReader *vfs.FileDescription @@ -755,7 +840,6 @@ func (fd *fileDescription) Release(ctx context.Context) { // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - // TODO(b/162788573): Add integrity check for metadata. stat, err := fd.lowerFD.Stat(ctx, opts) if err != nil { return linux.Statx{}, err @@ -794,7 +878,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa // Verify that the child is expected. if dirent.Name != "." && dirent.Name != ".." { if _, ok := fd.d.childrenNames[dirent.Name]; !ok { - return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name)) + return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name)) } } } @@ -808,7 +892,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa // The result should contain all children plus "." and "..". if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 { - return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) + return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) } for fd.off < int64(len(ds)) { @@ -875,10 +959,9 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui } params := &merkletree.GenerateParams{ - TreeReader: &merkleReader, - TreeWriter: &merkleWriter, - Children: fd.d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. 
+ TreeReader: &merkleReader, + TreeWriter: &merkleWriter, + Children: fd.d.childrenNames, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), Name: fd.d.name, Mode: uint32(stat.Mode), @@ -980,7 +1063,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { // or directory other than the root, the parent Merkle tree file should // have also been initialized. if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { - return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") + return 0, fd.d.fs.alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } hash, dataSize, err := fd.generateMerkleLocked(ctx) @@ -1053,7 +1136,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest hosta if fd.d.fs.allowRuntimeEnable { return 0, syserror.ENODATA } - return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found") + return 0, fd.d.fs.alertIntegrityViolation("Ioctl measureVerity: no hash found") } // The first part of VerityDigest is the metadata. @@ -1107,8 +1190,6 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. case linux.FS_IOC_GETFLAGS: return fd.verityFlags(ctx, args[2].Pointer()) default: - // TODO(b/169682228): Investigate which ioctl commands should - // be allowed. return 0, syserror.ENOSYS } } @@ -1143,7 +1224,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. 
if err == syserror.ENODATA { - return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { return 0, err @@ -1153,7 +1234,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } dataReader := FileReadWriteSeeker{ @@ -1168,16 +1249,15 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of fd.d.hashMu.RLock() n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, - Children: fd.d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + Children: fd.d.childrenNames, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), @@ -1186,7 +1266,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of }) fd.d.hashMu.RUnlock() if err != nil { - return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } return n, err } @@ -1201,6 +1281,24 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op return 0, syserror.EROFS } +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. 
+func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + if err := fd.lowerFD.ConfigureMMap(ctx, opts); err != nil { + return err + } + fd.lowerMappable = opts.Mappable + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef(ctx) + opts.MappingIdentity = nil + } + + // Check if mmap is allowed on the lower filesystem. + if !opts.SentryOwnedContent { + return syserror.ENODEV + } + return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts) +} + // LockBSD implements vfs.FileDescriptionImpl.LockBSD. func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return fd.lowerFD.LockBSD(ctx, ownerPID, t, block) @@ -1226,6 +1324,115 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t return fd.lowerFD.TestPOSIX(ctx, uid, t, r) } +// Translate implements memmap.Mappable.Translate. +func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { + ts, err := fd.lowerMappable.Translate(ctx, required, optional, at) + if err != nil { + return nil, err + } + + // dataSize is the size of the whole file. + dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the xattr does not exist, it + // indicates unexpected modifications to the file system. + if err == syserror.ENODATA { + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + } + if err != nil { + return nil, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. 
+ size, err := strconv.Atoi(dataSize) + if err != nil { + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + } + + merkleReader := FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + + for _, t := range ts { + // Content integrity relies on sentry owning the backing data. MapInternal is guaranteed + // to fetch sentry owned memory because we disallow verity mmaps otherwise. + ims, err := t.File.MapInternal(memmap.FileRange{t.Offset, t.Offset + t.Source.Length()}, hostarch.Read) + if err != nil { + return nil, err + } + dataReader := mmapReadSeeker{ims, t.Source.Start} + var buf bytes.Buffer + _, err = merkletree.Verify(&merkletree.VerifyParams{ + Out: &buf, + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), + ReadOffset: int64(t.Source.Start), + ReadSize: int64(t.Source.Length()), + Expected: fd.d.hash, + DataAndTreeInSameFile: false, + }) + if err != nil { + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + } + } + return ts, err +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (fd *fileDescription) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { + return fd.lowerMappable.AddMapping(ctx, ms, ar, offset, writable) +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (fd *fileDescription) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { + fd.lowerMappable.RemoveMapping(ctx, ms, ar, offset, writable) +} + +// CopyMapping implements memmap.Mappable.CopyMapping. 
+func (fd *fileDescription) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { + return fd.lowerMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable) +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (fd *fileDescription) InvalidateUnsavable(context.Context) error { + return nil +} + +// mmapReadSeeker is a helper struct used by fileDescription.Translate to pass +// a safemem.BlockSeq pointing to the mapped region as io.ReaderAt. +type mmapReadSeeker struct { + safemem.BlockSeq + Offset uint64 +} + +// ReadAt implements io.ReaderAt.ReadAt. off is the offset into the mapped file. +func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) { + bs := r.BlockSeq + // Adjust the offset into the mapped file to get the offset into the internally + // mapped region. + readOffset := off - int64(r.Offset) + if readOffset < 0 { + return 0, syserror.EINVAL + } + bs.DropFirst64(uint64(readOffset)) + view := bs.TakeFirst64(uint64(len(p))) + dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p)) + n, err := safemem.CopySeq(dst, view) + return int(n), err +} + // FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as // io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. 
type FileReadWriteSeeker struct { diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 57bd65202..5c78a0019 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -89,10 +89,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, AllowUserMount: true, }) + data := "root_name=" + rootMerkleFilename mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: data, InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, LowerName: "tmpfs", Alg: hashAlg, AllowRuntimeEnable: true, diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e9eb89378..a1ec6daab 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -141,6 +141,7 @@ go_library( srcs = [ "abstract_socket_namespace.go", "aio.go", + "cgroup.go", "context.go", "fd_table.go", "fd_table_refs.go", @@ -178,6 +179,7 @@ go_library( "task.go", "task_acct.go", "task_block.go", + "task_cgroup.go", "task_clone.go", "task_context.go", "task_exec.go", @@ -241,6 +243,7 @@ go_library( "//pkg/sentry/fs/lock", "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", + "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/pipefs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/fsimpl/timerfd", diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go new file mode 100644 index 000000000..1f1c63f37 --- /dev/null +++ b/pkg/sentry/kernel/cgroup.go @@ -0,0 +1,281 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "bytes" + "fmt" + "sort" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +// InvalidCgroupHierarchyID indicates an uninitialized hierarchy ID. +const InvalidCgroupHierarchyID uint32 = 0 + +// CgroupControllerType is the name of a cgroup controller. +type CgroupControllerType string + +// CgroupController is the common interface to cgroup controllers available to +// the entire sentry. The controllers themselves are defined by cgroupfs. +// +// Callers of this interface are often unable access synchronization needed to +// ensure returned values remain valid. Some of values returned from this +// interface are thus snapshots in time, and may become stale. This is ok for +// many callers like procfs. +type CgroupController interface { + // Returns the type of this cgroup controller (ex "memory", "cpu"). Returned + // value is valid for the lifetime of the controller. + Type() CgroupControllerType + + // Hierarchy returns the ID of the hierarchy this cgroup controller is + // attached to. Returned value is valid for the lifetime of the controller. + HierarchyID() uint32 + + // Filesystem returns the filesystem this controller is attached to. + // Returned value is valid for the lifetime of the controller. + Filesystem() *vfs.Filesystem + + // RootCgroup returns the root cgroup for this controller. Returned value is + // valid for the lifetime of the controller. 
+ RootCgroup() Cgroup + + // NumCgroups returns the number of cgroups managed by this controller. + // Returned value is a snapshot in time. + NumCgroups() uint64 + + // Enabled returns whether this controller is enabled. Returned value is a + // snapshot in time. + Enabled() bool +} + +// Cgroup represents a named pointer to a cgroup in cgroupfs. When a task enters +// a cgroup, it holds a reference on the underlying dentry pointing to the +// cgroup. +// +// +stateify savable +type Cgroup struct { + *kernfs.Dentry + CgroupImpl +} + +func (c *Cgroup) decRef() { + c.Dentry.DecRef(context.Background()) +} + +// Path returns the absolute path of c, relative to its hierarchy root. +func (c *Cgroup) Path() string { + return c.FSLocalPath() +} + +// HierarchyID returns the id of the hierarchy that contains this cgroup. +func (c *Cgroup) HierarchyID() uint32 { + // Note: a cgroup is guaranteed to have at least one controller. + return c.Controllers()[0].HierarchyID() +} + +// CgroupImpl is the common interface to cgroups. +type CgroupImpl interface { + Controllers() []CgroupController + Enter(t *Task) + Leave(t *Task) +} + +// hierarchy represents a cgroupfs filesystem instance, with a unique set of +// controllers attached to it. Multiple cgroupfs mounts may reference the same +// hierarchy. +// +// +stateify savable +type hierarchy struct { + id uint32 + // These are a subset of the controllers in CgroupRegistry.controllers, + // grouped here by hierarchy for conveninent lookup. + controllers map[CgroupControllerType]CgroupController + // fs is not owned by hierarchy. The FS is responsible for unregistering the + // hierarchy on destruction, which removes this association. 
+ fs *vfs.Filesystem +} + +func (h *hierarchy) match(ctypes []CgroupControllerType) bool { + if len(ctypes) != len(h.controllers) { + return false + } + for _, ty := range ctypes { + if _, ok := h.controllers[ty]; !ok { + return false + } + } + return true +} + +// CgroupRegistry tracks the active set of cgroup controllers on the system. +// +// +stateify savable +type CgroupRegistry struct { + // lastHierarchyID is the id of the last allocated cgroup hierarchy. Valid + // ids are from 1 to math.MaxUint32. Must be accessed through atomic ops. + // + lastHierarchyID uint32 + + mu sync.Mutex `state:"nosave"` + + // controllers is the set of currently known cgroup controllers on the + // system. Protected by mu. + // + // +checklocks:mu + controllers map[CgroupControllerType]CgroupController + + // hierarchies is the active set of cgroup hierarchies. Protected by mu. + // + // +checklocks:mu + hierarchies map[uint32]hierarchy +} + +func newCgroupRegistry() *CgroupRegistry { + return &CgroupRegistry{ + controllers: make(map[CgroupControllerType]CgroupController), + hierarchies: make(map[uint32]hierarchy), + } +} + +// nextHierarchyID returns a newly allocated, unique hierarchy ID. +func (r *CgroupRegistry) nextHierarchyID() (uint32, error) { + if hid := atomic.AddUint32(&r.lastHierarchyID, 1); hid != 0 { + return hid, nil + } + return InvalidCgroupHierarchyID, fmt.Errorf("cgroup hierarchy ID overflow") +} + +// FindHierarchy returns a cgroup filesystem containing exactly the set of +// controllers named in names. If no such FS is found, FindHierarchy return +// nil. FindHierarchy takes a reference on the returned FS, which is transferred +// to the caller. 
+func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Filesystem { + r.mu.Lock() + defer r.mu.Unlock() + + for _, h := range r.hierarchies { + if h.match(ctypes) { + h.fs.IncRef() + return h.fs + } + } + + return nil +} + +// Register registers the provided set of controllers with the registry as a new +// hierarchy. If any controller is already registered, the function returns an +// error without modifying the registry. The hierarchy can be later referenced +// by the returned id. +func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if len(cs) == 0 { + return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers") + } + + for _, c := range cs { + if _, ok := r.controllers[c.Type()]; ok { + return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy") + } + } + + hid, err := r.nextHierarchyID() + if err != nil { + return hid, err + } + + h := hierarchy{ + id: hid, + controllers: make(map[CgroupControllerType]CgroupController), + fs: cs[0].Filesystem(), + } + for _, c := range cs { + n := c.Type() + r.controllers[n] = c + h.controllers[n] = c + } + r.hierarchies[hid] = h + return hid, nil +} + +// Unregister removes a previously registered hierarchy from the registry. If +// the controller was not previously registered, Unregister is a no-op. +func (r *CgroupRegistry) Unregister(hid uint32) { + r.mu.Lock() + defer r.mu.Unlock() + + if h, ok := r.hierarchies[hid]; ok { + for name, _ := range h.controllers { + delete(r.controllers, name) + } + delete(r.hierarchies, hid) + } +} + +// computeInitialGroups takes a reference on each of the returned cgroups. The +// caller takes ownership of this returned reference. 
+func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[Cgroup]struct{} { + r.mu.Lock() + defer r.mu.Unlock() + + ctlSet := make(map[CgroupControllerType]CgroupController) + cgset := make(map[Cgroup]struct{}) + + // Remember controllers from the inherited cgroups set... + for cg, _ := range inherit { + cg.IncRef() // Ref transferred to caller. + for _, ctl := range cg.Controllers() { + ctlSet[ctl.Type()] = ctl + cgset[cg] = struct{}{} + } + } + + // ... and add the root cgroups of all the missing controllers. + for name, ctl := range r.controllers { + if _, ok := ctlSet[name]; !ok { + cg := ctl.RootCgroup() + cg.IncRef() // Ref transferred to caller. + cgset[cg] = struct{}{} + } + } + return cgset +} + +// GenerateProcCgroups writes the contents of /proc/cgroups to buf. +func (r *CgroupRegistry) GenerateProcCgroups(buf *bytes.Buffer) { + r.mu.Lock() + entries := make([]string, 0, len(r.controllers)) + for _, c := range r.controllers { + en := 0 + if c.Enabled() { + en = 1 + } + entries = append(entries, fmt.Sprintf("%s\t%d\t%d\t%d\n", c.Type(), c.HierarchyID(), c.NumCgroups(), en)) + } + r.mu.Unlock() + + sort.Strings(entries) + fmt.Fprint(buf, "#subsys_name\thierarchy\tnum_cgroups\tenabled\n") + for _, e := range entries { + fmt.Fprint(buf, e) + } +} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 43065b45a..e6e9da898 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -294,6 +294,11 @@ type Kernel struct { // YAMAPtraceScope is the current level of YAMA ptrace restrictions. YAMAPtraceScope int32 + + // cgroupRegistry contains the set of active cgroup controllers on the + // system. It is controller by cgroupfs. Nil if cgroupfs is unavailable on + // the system. + cgroupRegistry *CgroupRegistry } // InitKernelArgs holds arguments to Init. 
@@ -438,6 +443,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.socketMount = socketMount k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord) + + k.cgroupRegistry = newCgroupRegistry() } return nil } @@ -1815,6 +1822,11 @@ func (k *Kernel) SocketMount() *vfs.Mount { return k.socketMount } +// CgroupRegistry returns the cgroup registry. +func (k *Kernel) CgroupRegistry() *CgroupRegistry { + return k.cgroupRegistry +} + // Release releases resources owned by k. // // Precondition: This should only be called after the kernel is fully @@ -1831,3 +1843,43 @@ func (k *Kernel) Release() { k.timekeeper.Destroy() k.vdso.Release(ctx) } + +// PopulateNewCgroupHierarchy moves all tasks into a newly created cgroup +// hierarchy. +// +// Precondition: root must be a new cgroup with no tasks. This implies the +// controllers for root are also new and currently manage no task, which in turn +// implies the new cgroup can be populated without migrating tasks between +// cgroups. +func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) { + k.tasks.mu.RLock() + k.tasks.forEachTaskLocked(func(t *Task) { + if t.exitState != TaskExitNone { + return + } + t.mu.Lock() + t.enterCgroupLocked(root) + t.mu.Unlock() + }) + k.tasks.mu.RUnlock() +} + +// ReleaseCgroupHierarchy moves all tasks out of all cgroups belonging to the +// hierarchy with the provided id. This is intended for use during hierarchy +// teardown, as otherwise the tasks would be orphaned w.r.t to some controllers. 
+func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) { + k.tasks.mu.RLock() + k.tasks.forEachTaskLocked(func(t *Task) { + if t.exitState != TaskExitNone { + return + } + t.mu.Lock() + for cg, _ := range t.cgroups { + if cg.HierarchyID() == hid { + t.leaveCgroupLocked(cg) + } + } + t.mu.Unlock() + }) + k.tasks.mu.RUnlock() +} diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 399985039..be1371855 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -587,6 +587,12 @@ type Task struct { // // kcov is exclusive to the task goroutine. kcov *Kcov + + // cgroups is the set of cgroups this task belongs to. This may be empty if + // no cgroup controllers are enabled. Protected by mu. + // + // +checklocks:mu + cgroups map[Cgroup]struct{} } func (t *Task) savePtraceTracer() *Task { diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go new file mode 100644 index 000000000..25d2504fa --- /dev/null +++ b/pkg/sentry/kernel/task_cgroup.go @@ -0,0 +1,138 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "bytes" + "fmt" + "sort" + "strings" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/syserror" +) + +// EnterInitialCgroups moves t into an initial set of cgroups. +// +// Precondition: t isn't in any cgroups yet, t.cgs is empty. +// +// +checklocksignore parent.mu is conditionally acquired. 
+func (t *Task) EnterInitialCgroups(parent *Task) { + var inherit map[Cgroup]struct{} + if parent != nil { + parent.mu.Lock() + defer parent.mu.Unlock() + inherit = parent.cgroups + } + joinSet := t.k.cgroupRegistry.computeInitialGroups(inherit) + + t.mu.Lock() + defer t.mu.Unlock() + // Transfer ownership of joinSet refs to the task's cgset. + t.cgroups = joinSet + for c, _ := range t.cgroups { + // Since t isn't in any cgroup yet, we can skip the check against + // existing cgroups. + c.Enter(t) + } +} + +// EnterCgroup moves t into c. +func (t *Task) EnterCgroup(c Cgroup) error { + newControllers := make(map[CgroupControllerType]struct{}) + for _, ctl := range c.Controllers() { + newControllers[ctl.Type()] = struct{}{} + } + + t.mu.Lock() + defer t.mu.Unlock() + + for oldCG, _ := range t.cgroups { + for _, oldCtl := range oldCG.Controllers() { + if _, ok := newControllers[oldCtl.Type()]; ok { + // Already in a cgroup with the same controller as one of the + // new ones. Requires migration between cgroups. + // + // TODO(b/183137098): Implement cgroup migration. + log.Warningf("Cgroup migration is not implemented") + return syserror.EBUSY + } + } + } + + // No migration required. + t.enterCgroupLocked(c) + + return nil +} + +// +checklocks:t.mu +func (t *Task) enterCgroupLocked(c Cgroup) { + c.IncRef() + t.cgroups[c] = struct{}{} + c.Enter(t) +} + +// LeaveCgroups removes t out from all its cgroups. +func (t *Task) LeaveCgroups() { + t.mu.Lock() + defer t.mu.Unlock() + for c, _ := range t.cgroups { + t.leaveCgroupLocked(c) + } +} + +// +checklocks:t.mu +func (t *Task) leaveCgroupLocked(c Cgroup) { + c.Leave(t) + delete(t.cgroups, c) + c.decRef() +} + +// taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to +// format a cgroup for display. +type taskCgroupEntry struct { + hierarchyID uint32 + controllers string + path string +} + +// GenerateProcTaskCgroup writes the contents of /proc/<pid>/cgroup for t to buf. 
+func (t *Task) GenerateProcTaskCgroup(buf *bytes.Buffer) { + t.mu.Lock() + defer t.mu.Unlock() + + cgEntries := make([]taskCgroupEntry, 0, len(t.cgroups)) + for c, _ := range t.cgroups { + ctls := c.Controllers() + ctlNames := make([]string, 0, len(ctls)) + for _, ctl := range ctls { + ctlNames = append(ctlNames, string(ctl.Type())) + } + + cgEntries = append(cgEntries, taskCgroupEntry{ + // Note: We're guaranteed to have at least one controller, and all + // controllers are guaranteed to be on the same hierarchy. + hierarchyID: ctls[0].HierarchyID(), + controllers: strings.Join(ctlNames, ","), + path: c.Path(), + }) + } + + sort.Slice(cgEntries, func(i, j int) bool { return cgEntries[i].hierarchyID > cgEntries[j].hierarchyID }) + for _, cgE := range cgEntries { + fmt.Fprintf(buf, "%d:%s:%s\n", cgE.hierarchyID, cgE.controllers, cgE.path) + } +} diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index ad59e4f60..b1af1a7ef 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -275,6 +275,10 @@ func (*runExitMain) execute(t *Task) taskRunState { t.fsContext.DecRef(t) t.fdTable.DecRef(t) + // Detach task from all cgroups. This must happen before potentially the + // last ref to the cgroupfs mount is dropped below. 
+ t.LeaveCgroups() + t.mu.Lock() if t.mountNamespaceVFS2 != nil { t.mountNamespaceVFS2.DecRef(t) diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index fc18b6253..32031cd70 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -151,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { rseqSignature: cfg.RSeqSignature, futexWaiter: futex.NewWaiter(), containerID: cfg.ContainerID, + cgroups: make(map[Cgroup]struct{}), } t.creds.Store(cfg.Credentials) t.endStopCond.L = &t.tg.signalHandlers.mu @@ -189,6 +190,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { t.parent.children[t] = struct{}{} } + if VFS2Enabled { + t.EnterInitialCgroups(t.parent) + } + if tg.leader == nil { // New thread group. tg.leader = t diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 2c658d001..601fc0d3a 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -30,8 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application") - // SyscallRestartBlock represents the restart block for a syscall restartable // with a custom function. It encapsulates the state required to restart a // syscall across a S/R. @@ -284,7 +282,7 @@ func (*runSyscallExit) execute(t *Task) taskRunState { // indicated by an execution fault at address addr. doVsyscall returns the // task's next run state. func (t *Task) doVsyscall(addr hostarch.Addr, sysno uintptr) taskRunState { - vsyscallCount.Increment() + metric.WeirdnessMetric.Increment("vsyscall_count") // Grab the caller up front, to make sure there's a sensible stack. 
caller := t.Arch().Native(uintptr(0)) diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 09d070ec8..77ad62445 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -114,6 +114,15 @@ func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) { } } +// forEachTaskLocked applies f to each Task in ts. +// +// Preconditions: ts.mu must be locked (for reading or writing). +func (ts *TaskSet) forEachTaskLocked(f func(t *Task)) { + for t := range ts.Root.tids { + f(t) + } +} + // A PIDNamespace represents a PID namespace, a bimap between thread IDs and // tasks. See the pid_namespaces(7) man page for further details. // diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index ecb6603a1..4c65215fa 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -11,11 +11,12 @@ go_library( "vdso.go", "vdso_state.go", ], + marshal = True, + marshal_debug = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/cpuid", "//pkg/hostarch", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index e92d9fdc3..8fc3e2a79 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/hostarch" @@ -47,10 +46,10 @@ const ( var ( // header64Size is the size of elf.Header64. - header64Size = int(binary.Size(elf.Header64{})) + header64Size = (*linux.ElfHeader64)(nil).SizeBytes() // Prog64Size is the size of elf.Prog64. 
- prog64Size = int(binary.Size(elf.Prog64{})) + prog64Size = (*linux.ElfProg64)(nil).SizeBytes() ) func progFlagsAsPerms(f elf.ProgFlag) hostarch.AccessType { @@ -136,7 +135,6 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { log.Infof("Unsupported ELF endianness: %v", endian) return elfInfo{}, syserror.ENOEXEC } - byteOrder := binary.LittleEndian if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT { log.Infof("Unsupported ELF version: %v", version) @@ -145,7 +143,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // EI_OSABI is ignored by Linux, which is the only OS supported. os := abi.Linux - var hdr elf.Header64 + var hdr linux.ElfHeader64 hdrBuf := make([]byte, header64Size) _, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0) if err != nil { @@ -156,7 +154,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { } return elfInfo{}, err } - binary.Unmarshal(hdrBuf, byteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBuf) // We support amd64 and arm64. var a arch.Arch @@ -213,8 +211,8 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { phdrs := make([]elf.ProgHeader, hdr.Phnum) for i := range phdrs { - var prog64 elf.Prog64 - binary.Unmarshal(phdrBuf[:prog64Size], byteOrder, &prog64) + var prog64 linux.ElfProg64 + prog64.UnmarshalUnsafe(phdrBuf[:prog64Size]) phdrBuf = phdrBuf[prog64Size:] phdrs[i] = elf.ProgHeader{ Type: elf.ProgType(prog64.Type), diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 72868646a..610686ea0 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -375,6 +375,11 @@ type MMapOpts struct { // // If Force is true, Unmap and Fixed must be true. Force bool + + // SentryOwnedContent indicates the sentry exclusively controls the + // underlying memory backing the mapping thus the memory content is + // guaranteed not to be modified outside the sentry's purview. 
+ SentryOwnedContent bool } // File represents a host file that may be mapped into an platform.AddressSpace. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index f04898dc1..b307832fd 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -65,6 +65,7 @@ go_test( name = "kvm_test", srcs = [ "kvm_amd64_test.go", + "kvm_amd64_test.s", "kvm_arm64_test.go", "kvm_test.go", "virtual_map_test.go", diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index fd1131638..bb9967b9f 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -16,7 +16,6 @@ package kvm import ( "fmt" - "reflect" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/ring0" @@ -36,6 +35,14 @@ func sighandler() // dieArchSetup and the assembly implementation for dieTrampoline. func dieTrampoline() +// Return the start address of the functions above. +// +// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal +// wrapper function rather than the function itself. We must reference from +// assembly to get the ABI0 (i.e., primary) address. +func addrOfSighandler() uintptr +func addrOfDieTrampoline() uintptr + var ( // bounceSignal is the signal used for bouncing KVM. // @@ -87,10 +94,10 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { func init() { // Install the handler. - if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil { + if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil { panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err)) } // Extract the address for the trampoline. 
- dieTrampolineAddr = reflect.ValueOf(dieTrampoline).Pointer() + dieTrampolineAddr = addrOfDieTrampoline() } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index 025ea93b5..953024600 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -81,8 +81,20 @@ fallback: MOVQ ·savedHandler(SB), AX JMP AX +// func addrOfSighandler() uintptr +TEXT ·addrOfSighandler(SB), $0-8 + MOVQ $·sighandler(SB), AX + MOVQ AX, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 PUSHQ BX // First argument (vCPU). PUSHQ AX // Fake the old RIP as caller. JMP ·dieHandler(SB) + +// func addrOfDieTrampoline() uintptr +TEXT ·addrOfDieTrampoline(SB), $0-8 + MOVQ $·dieTrampoline(SB), AX + MOVQ AX, ret+0(FP) + RET diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 09c7e88e5..308f2a951 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -92,6 +92,12 @@ fallback: MOVD ·savedHandler(SB), R7 B (R7) +// func addrOfSighandler() uintptr +TEXT ·addrOfSighandler(SB), $0-8 + MOVD $·sighandler(SB), R0 + MOVD R0, ret+0(FP) + RET + // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. 
TEXT ·dieTrampoline(SB),NOSPLIT,$0 // R0: Fake the old PC as caller @@ -99,3 +105,9 @@ TEXT ·dieTrampoline(SB),NOSPLIT,$0 MOVD.P R1, 8(RSP) // R1: First argument (vCPU) MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller B ·dieHandler(SB) + +// func addrOfDieTrampoline() uintptr +TEXT ·addrOfDieTrampoline(SB), $0-8 + MOVD $·dieTrampoline(SB), R0 + MOVD R0, ret+0(FP) + RET diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go index e44e995a0..b8dd1e4a5 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -49,3 +49,40 @@ func TestSegments(t *testing.T) { return false }) } + +// stmxcsr reads the MXCSR control and status register. +func stmxcsr(addr *uint32) + +func TestMXCSR(t *testing.T) { + applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { + var si arch.SignalInfo + switchOpts := ring0.SwitchOpts{ + Registers: regs, + FloatingPointState: &dummyFPState, + PageTables: pt, + FullRestore: true, + } + + const mxcsrControllMask = uint32(0x1f80) + mxcsrBefore := uint32(0) + mxcsrAfter := uint32(0) + stmxcsr(&mxcsrBefore) + if mxcsrBefore == 0 { + // goruntime sets mxcsr to 0x1f80 and it never changes + // the control configuration. + panic("mxcsr is zero") + } + switchOpts.FloatingPointState.SetMXCSR(0) + if _, err := c.SwitchToUser( + switchOpts, &si); err == platform.ErrContextInterrupt { + return true // Retry. 
+ } else if err != nil { + t.Errorf("application syscall failed: %v", err) + } + stmxcsr(&mxcsrAfter) + if mxcsrAfter&mxcsrControllMask != mxcsrBefore&mxcsrControllMask { + t.Errorf("mxcsr = %x (expected %x)", mxcsrBefore, mxcsrAfter) + } + return false + }) +} diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.s b/pkg/sentry/platform/kvm/kvm_amd64_test.s new file mode 100644 index 000000000..8e9079867 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.s @@ -0,0 +1,21 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +// stmxcsr reads the MXCSR control and status register. +TEXT ·stmxcsr(SB),NOSPLIT,$0-8 + MOVQ addr+0(FP), SI + STMXCSR (SI) + RET diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 2492d57be..eb2dcccac 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -66,6 +66,7 @@ const ( _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 _KVM_CAP_VCPU_EVENTS = 0x29 _KVM_CAP_ARM_INJECT_SERROR_ESR = 0x9e + _KVM_CAP_TSC_CONTROL = 0x3c ) // KVM limits. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index b3d4188a3..1b5d5f66e 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -67,11 +67,17 @@ type machine struct { // maxSlots is the maximum number of memory slots supported by the machine. 
maxSlots int + // tscControl checks whether cpu supports TSC scaling + tscControl bool + // usedSlots is the set of used physical addresses (sorted). usedSlots []uintptr // nextID is the next vCPU ID. nextID uint32 + + // machineArchState is the architecture-specific state. + machineArchState } const ( @@ -193,12 +199,7 @@ func newMachine(vm int) (*machine, error) { m.available.L = &m.mu // Pull the maximum vCPUs. - maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) - if errno != 0 { - m.maxVCPUs = _KVM_NR_VCPUS - } else { - m.maxVCPUs = int(maxVCPUs) - } + m.getMaxVCPU() log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) @@ -214,6 +215,11 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of slots is %d.", m.maxSlots) m.usedSlots = make([]uintptr, m.maxSlots) + // Check TSC Scaling + hasTSCControl, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_TSC_CONTROL) + m.tscControl = errno == 0 && hasTSCControl == 1 + log.Debugf("TSC scaling support: %t.", m.tscControl) + // Create the upper shared pagetables and kernel(sentry) pagetables. m.upperSharedPageTables = pagetables.New(newAllocator()) m.mapUpperHalf(m.upperSharedPageTables) @@ -419,9 +425,8 @@ func (m *machine) Get() *vCPU { } } - // Create a new vCPU (maybe). - if int(m.nextID) < m.maxVCPUs { - c := m.newVCPU() + // Get a new vCPU (maybe). 
+ if c := m.getNewVCPU(); c != nil { c.lock() m.vCPUsByTID[tid] = c m.mu.Unlock() diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index e8e209249..9a2337654 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -63,6 +63,9 @@ func (m *machine) initArchState() error { return nil } +type machineArchState struct { +} + type vCPUArchState struct { // PCIDs is the set of PCIDs for this vCPU. // @@ -213,6 +216,11 @@ func (c *vCPU) setSystemTime() error { // capabilities as it is emulated in KVM. We don't actually use this // capability, but it means that this method should be robust to // different hardware configurations. + + // if tsc scaling is not supported, fallback to legacy mode + if !c.machine.tscControl { + return c.setSystemTimeLegacy() + } rawFreq, err := c.getTSCFreq() if err != nil { return c.setSystemTimeLegacy() @@ -346,6 +354,10 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) // allocations occur. entersyscall() bluepill(c) + // The root table physical page has to be mapped to not fault in iret + // or sysret after switching into a user address space. sysret and + // iret are in the upper half that is global and already mapped. 
+ switchOpts.PageTables.PrefaultRootTable() prefaultFloatingPointState(switchOpts.FloatingPointState) vector = c.CPU.SwitchToUser(switchOpts) exitsyscall() @@ -490,3 +502,22 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { physical) } } + +// getMaxVCPU get max vCPU number +func (m *machine) getMaxVCPU() { + maxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) + if errno != 0 { + m.maxVCPUs = _KVM_NR_VCPUS + } else { + m.maxVCPUs = int(maxVCPUs) + } +} + +// getNewVCPU create a new vCPU (maybe) +func (m *machine) getNewVCPU() *vCPU { + if int(m.nextID) < m.maxVCPUs { + c := m.newVCPU() + return c + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 03e84d804..8926b1d9f 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -17,6 +17,10 @@ package kvm import ( + "runtime" + "sync/atomic" + + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/ring0/pagetables" @@ -25,6 +29,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" ) +type machineArchState struct { + //initialvCPUs is the machine vCPUs which has initialized but not used + initialvCPUs map[int]*vCPU +} + type vCPUArchState struct { // PCIDs is the set of PCIDs for this vCPU. // @@ -47,7 +56,7 @@ const ( // Beyond a relatively small number, there are likely few perform // benefits, since the TLB has likely long since lost any translations // from more than a few PCIDs past. 
- poolPCIDs = 8 + poolPCIDs = 128 ) func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { @@ -182,3 +191,30 @@ func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (hostarch.AccessType, return accessType, platform.ErrContextSignal } + +// getMaxVCPU get max vCPU number +func (m *machine) getMaxVCPU() { + rmaxVCPUs := runtime.NumCPU() + smaxVCPUs, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) + // compare the max vcpu number from runtime and syscall, use smaller one. + if errno != 0 { + m.maxVCPUs = rmaxVCPUs + } else { + if rmaxVCPUs < int(smaxVCPUs) { + m.maxVCPUs = rmaxVCPUs + } else { + m.maxVCPUs = int(smaxVCPUs) + } + } +} + +// getNewVCPU() scan for an available vCPU from initialvCPUs +func (m *machine) getNewVCPU() *vCPU { + for CID, c := range m.initialvCPUs { + if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) { + delete(m.initialvCPUs, CID) + return c + } + } + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 634e55ec0..92edc992b 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/arch/fpu" "gvisor.dev/gvisor/pkg/sentry/platform" + ktime "gvisor.dev/gvisor/pkg/sentry/time" ) type kvmVcpuInit struct { @@ -47,6 +48,19 @@ func (m *machine) initArchState() error { uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 { panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno)) } + + // Initialize all vCPUs on ARM64, while this does not happen on x86_64. + // The reason for the difference is that ARM64 and x86_64 have different KVM timer mechanisms. + // If we create vCPU dynamically on ARM64, the timer for vCPU would mess up for a short time. 
+ // For more detail, please refer to https://github.com/google/gvisor/issues/5739 + m.initialvCPUs = make(map[int]*vCPU) + m.mu.Lock() + for int(m.nextID) < m.maxVCPUs-1 { + c := m.newVCPU() + c.state = 0 + m.initialvCPUs[c.id] = c + } + m.mu.Unlock() return nil } @@ -174,9 +188,58 @@ func (c *vCPU) setTSC(value uint64) error { return nil } +// getTSC gets the counter Physical Counter minus Virtual Offset. +func (c *vCPU) getTSC() error { + var ( + reg kvmOneReg + data uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + reg.id = _KVM_ARM64_REGS_TIMER_CNT + + if err := c.getOneRegister(&reg); err != nil { + return err + } + + return nil +} + // setSystemTime sets the vCPU to the system time. func (c *vCPU) setSystemTime() error { - return c.setSystemTimeLegacy() + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Use get the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + // replace getTSC to another setOneRegister syscall can get more accurate value? + start := uint64(ktime.Rdtsc()) + if err := c.getTSC(); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? 
+ upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && (current <= upperThreshold || minimum < 50) { + // Try to set the TSC + if err := c.setTSC(end + (minimum / 2)); err != nil { + return err + } + return nil + } + } } //go:nosplit @@ -203,7 +266,7 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { uintptr(c.fd), _KVM_GET_ONE_REG, uintptr(unsafe.Pointer(reg))); errno != 0 { - return fmt.Errorf("error setting one register: %v", errno) + return fmt.Errorf("error getting one register: %v", errno) } return nil } diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s index 16f9c523e..d5c3f901f 100644 --- a/pkg/sentry/platform/ptrace/stub_amd64.s +++ b/pkg/sentry/platform/ptrace/stub_amd64.s @@ -109,6 +109,12 @@ parent_dead: SYSCALL HLT +// func addrOfStub() uintptr +TEXT ·addrOfStub(SB), $0-8 + MOVQ $·stub(SB), AX + MOVQ AX, ret+0(FP) + RET + // stubCall calls the stub function at the given address with the given PPID. // // This is a distinct function because stub, above, may be mapped at any diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s index 6162df02a..4664cd4ad 100644 --- a/pkg/sentry/platform/ptrace/stub_arm64.s +++ b/pkg/sentry/platform/ptrace/stub_arm64.s @@ -102,6 +102,12 @@ parent_dead: SVC HLT +// func addrOfStub() uintptr +TEXT ·addrOfStub(SB), $0-8 + MOVD $·stub(SB), R0 + MOVD R0, ret+0(FP) + RET + // stubCall calls the stub function at the given address with the given PPID. // // This is a distinct function because stub, above, may be mapped at any diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go index 5c9b7784f..1fbdea898 100644 --- a/pkg/sentry/platform/ptrace/stub_unsafe.go +++ b/pkg/sentry/platform/ptrace/stub_unsafe.go @@ -26,6 +26,13 @@ import ( // stub is defined in arch-specific assembly. func stub() +// addrOfStub returns the start address of stub. 
+// +// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal +// wrapper function rather than the function itself. We must reference from +// assembly to get the ABI0 (i.e., primary) address. +func addrOfStub() uintptr + // stubCall calls the stub at the given address with the given pid. func stubCall(addr, pid uintptr) @@ -41,7 +48,7 @@ func unsafeSlice(addr uintptr, length int) (slice []byte) { // stubInit initializes the stub. func stubInit() { // Grab the existing stub. - stubBegin := reflect.ValueOf(stub).Pointer() + stubBegin := addrOfStub() stubLen := int(safecopy.FindEndAddress(stubBegin) - stubBegin) stubSlice := unsafeSlice(stubBegin, stubLen) mapLen := uintptr(stubLen) diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 080859125..7ee89a735 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/hostarch", "//pkg/marshal", diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 0e0e82365..2029e7cf4 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -14,9 +14,11 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/context", "//pkg/hostarch", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 45a05cd63..235b9c306 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -18,9 +18,11 @@ package control import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + 
"gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -193,7 +195,7 @@ func putUint32(buf []byte, n uint32) []byte { // putCmsg writes a control message header and as much data as will fit into // the unused capacity of a buffer. func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) { - space := binary.AlignDown(cap(buf)-len(buf), 4) + space := bits.AlignDown(cap(buf)-len(buf), 4) // We can't write to space that doesn't exist, so if we are going to align // the available space, we must align down. @@ -230,7 +232,7 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ return alignSlice(buf, align), flags } -func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte { +func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data marshal.Marshallable) []byte { if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader { return buf } @@ -241,8 +243,7 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = putUint32(buf, msgType) hdrBuf := buf - - buf = binary.Marshal(buf, hostarch.ByteOrder, data) + buf = append(buf, marshal.Marshal(data)...) // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { @@ -288,7 +289,7 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int // alignSlice extends a slice's length (up to the capacity) to align it. func alignSlice(buf []byte, align uint) []byte { - aligned := binary.AlignUp(len(buf), align) + aligned := bits.AlignUp(len(buf), align) if aligned > cap(buf) { // Linux allows unaligned data if there isn't room for alignment. // Since there isn't room for alignment, there isn't room for any @@ -300,12 +301,13 @@ func alignSlice(buf []byte, align uint) []byte { // PackTimestamp packs a SO_TIMESTAMP socket control message. 
func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte { + timestampP := linux.NsecToTimeval(timestamp) return putCmsgStruct( buf, linux.SOL_SOCKET, linux.SO_TIMESTAMP, t.Arch().Width(), - linux.NsecToTimeval(timestamp), + &timestampP, ) } @@ -316,7 +318,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { linux.SOL_TCP, linux.TCP_INQ, t.Arch().Width(), - inq, + primitive.AllocateInt32(inq), ) } @@ -327,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { linux.SOL_IP, linux.IP_TOS, t.Arch().Width(), - tos, + primitive.AllocateUint8(tos), ) } @@ -338,7 +340,7 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { linux.SOL_IPV6, linux.IPV6_TCLASS, t.Arch().Width(), - tClass, + primitive.AllocateUint32(tClass), ) } @@ -423,7 +425,7 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt // cmsgSpace is equivalent to CMSG_SPACE in Linux. func cmsgSpace(t *kernel.Task, dataLen int) int { - return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width()) + return linux.SizeOfControlMessageHeader + bits.AlignUp(dataLen, t.Arch().Width()) } // CmsgsSpace returns the number of bytes needed to fit the control messages @@ -475,7 +477,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) + h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, syserror.EINVAL @@ -491,7 +493,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) case linux.SOL_SOCKET: switch h.Type { case linux.SCM_RIGHTS: - rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) numRights := rightsSize / linux.SizeOfControlMessageRight if 
len(fds)+numRights > linux.SCM_MAX_FD { @@ -502,7 +504,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) } - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -510,23 +512,23 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) + creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { return socket.ControlMessages{}, syserror.EINVAL } var ts linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &ts) + ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) cmsgs.IP.Timestamp = ts.ToNsecCapped() cmsgs.IP.HasTimestamp = true - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: // Unknown message type. 
@@ -539,8 +541,10 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTOS = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &cmsgs.IP.TOS) - i += binary.AlignUp(length, width) + var tos primitive.Uint8 + tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + cmsgs.IP.TOS = uint8(tos) + i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -549,19 +553,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) + packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) cmsgs.IP.PacketInfo = packetInfo - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -571,7 +575,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) cmsgs.IP.SockErr = &errCmsg - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL @@ -583,17 +587,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTClass = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], 
hostarch.ByteOrder, &cmsgs.IP.TClass) - i += binary.AlignUp(length, width) + var tclass primitive.Uint32 + tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + cmsgs.IP.TClass = uint32(tclass) + i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -603,7 +609,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) cmsgs.IP.SockErr = &errCmsg - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a5c2155a2..2e3064565 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -17,7 +17,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fdnotifier", "//pkg/hostarch", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index a784e23b5..52ae4bc9c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -19,7 +19,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" @@ -528,24 +527,28 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true - 
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], hostarch.ByteOrder, &controlMessages.IP.Timestamp) + ts := linux.Timeval{} + ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) + controlMessages.IP.Timestamp = ts.ToNsecCapped() } case linux.SOL_IP: switch unixCmsg.Header.Type { case linux.IP_TOS: controlMessages.IP.HasTOS = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &controlMessages.IP.TOS) + var tos primitive.Uint8 + tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()]) + controlMessages.IP.TOS = uint8(tos) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) + packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()]) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -558,11 +561,13 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &controlMessages.IP.TClass) + var tclass primitive.Uint32 + tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()]) + controlMessages.IP.TClass = uint32(tclass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -575,7 +580,9 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch 
unixCmsg.Header.Type { case linux.TCP_INQ: controlMessages.IP.HasInq = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], hostarch.ByteOrder, &controlMessages.IP.Inq) + var inq primitive.Int32 + inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq]) + controlMessages.IP.Inq = int32(inq) } } } @@ -689,7 +696,7 @@ func (s *socketOpsCommon) State() uint32 { return 0 } - binary.Unmarshal(buf, hostarch.ByteOrder, &info) + info.UnmarshalUnsafe(buf[:info.SizeBytes()]) return uint32(info.State) } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 26e8ae17a..393a1ab3a 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -15,6 +15,7 @@ package hostinet import ( + "encoding/binary" "fmt" "io" "io/ioutil" @@ -26,10 +27,10 @@ import ( "syscall" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -147,8 +148,8 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli if len(link.Data) < unix.SizeofIfInfomsg { return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } - var ifinfo unix.IfInfomsg - binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], hostarch.ByteOrder, &ifinfo) + var ifinfo linux.InterfaceInfoMessage + ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()]) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -178,11 +179,11 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli if len(addr.Data) < unix.SizeofIfAddrmsg { return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid 
data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } - var ifaddr unix.IfAddrmsg - binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], hostarch.ByteOrder, &ifaddr) + var ifaddr linux.InterfaceAddrMessage + ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()]) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, - PrefixLen: ifaddr.Prefixlen, + PrefixLen: ifaddr.PrefixLen, Flags: ifaddr.Flags, } attrs, err := syscall.ParseNetlinkRouteAttr(&addr) @@ -210,13 +211,13 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) continue } - var ifRoute unix.RtMsg - binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], hostarch.ByteOrder, &ifRoute) + var ifRoute linux.RouteMessage + ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()]) inetRoute := inet.Route{ Family: ifRoute.Family, - DstLen: ifRoute.Dst_len, - SrcLen: ifRoute.Src_len, - TOS: ifRoute.Tos, + DstLen: ifRoute.DstLen, + SrcLen: ifRoute.SrcLen, + TOS: ifRoute.TOS, Table: ifRoute.Table, Protocol: ifRoute.Protocol, Scope: ifRoute.Scope, @@ -245,7 +246,9 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) if len(attr.Value) != expected { return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected) } - binary.Unmarshal(attr.Value, hostarch.ByteOrder, &inetRoute.OutputInterface) + var outputIF primitive.Int32 + outputIF.UnmarshalUnsafe(attr.Value) + inetRoute.OutputInterface = int32(outputIF) } } diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 4381dfa06..61b2c9755 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -14,14 +14,16 @@ go_library( "tcp_matcher.go", "udp_matcher.go", ], + marshal = True, # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. 
visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/hostarch", "//pkg/log", + "//pkg/marshal", "//pkg/sentry/kernel", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 4bd305a44..6fc7781ad 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -79,7 +78,7 @@ func marshalEntryMatch(name string, data []byte) []byte { nflog("marshaling matcher %q", name) // We have to pad this struct size to a multiple of 8 bytes. - size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) + size := bits.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) matcher := linux.KernelXTEntryMatch{ XTEntryMatch: linux.XTEntryMatch{ MatchSize: uint16(size), @@ -88,9 +87,11 @@ func marshalEntryMatch(name string, data []byte) []byte { } copy(matcher.Name[:], name) - buf := make([]byte, 0, size) - buf = binary.Marshal(buf, hostarch.ByteOrder, matcher) - return append(buf, make([]byte, size-len(buf))...) 
+ buf := make([]byte, size) + entryLen := matcher.XTEntryMatch.SizeBytes() + matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen]) + copy(buf[entryLen:], matcher.Data) + return buf } func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 1fc4cb651..cb78ef60b 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -18,8 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -141,10 +139,9 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } var entry linux.IPTEntry - buf := optVal[:linux.SizeOfIPTEntry] - binary.Unmarshal(buf, hostarch.ByteOrder, &entry) + entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIPTEntry:] + optVal = optVal[entry.SizeBytes():] if entry.TargetOffset < linux.SizeOfIPTEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 67a52b628..5cb7fe4aa 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -18,8 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -144,10 +142,9 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } var entry linux.IP6TEntry - buf := optVal[:linux.SizeOfIP6TEntry] - binary.Unmarshal(buf, hostarch.ByteOrder, &entry) + 
entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIP6TEntry:] + optVal = optVal[entry.SizeBytes():] if entry.TargetOffset < linux.SizeOfIP6TEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 5200e08ed..f42d73178 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -22,7 +22,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -121,7 +120,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if binary.Size(entries) > uintptr(outLen) { + if entries.SizeBytes() > outLen { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } @@ -146,7 +145,7 @@ func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe nflog("couldn't read entries: %v", err) return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } - if binary.Size(entries) > uintptr(outLen) { + if entries.SizeBytes() > outLen { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } @@ -179,7 +178,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] - binary.Unmarshal(replaceBuf, hostarch.ByteOrder, &replace) + replace.UnmarshalBytes(replaceBuf) // TODO(gvisor.dev/issue/170): Support other tables. 
var table stack.Table @@ -274,10 +273,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } // TODO(gvisor.dev/issue/170): Support other chains. - // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, - // make sure all other chains point to ACCEPT rules. + // Since we don't support FORWARD, yet, make sure all other chains point to + // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if hook := stack.Hook(hook); hook == stack.Forward { if ruleIdx == stack.HookUnset { continue } @@ -309,8 +308,8 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch - buf := optVal[:linux.SizeOfXTEntryMatch] - binary.Unmarshal(buf, hostarch.ByteOrder, &match) + buf := optVal[:match.SizeBytes()] + match.UnmarshalUnsafe(buf) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index b2cc6be20..60845cab3 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -59,8 +58,8 @@ func (ownerMarshaler) marshal(mr matcher) []byte { } } - buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) - return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, hostarch.ByteOrder, iptOwnerInfo)) + buf := marshal.Marshal(&iptOwnerInfo) + return marshalEntryMatch(matcherNameOwner, buf) } // unmarshal implements matchMaker.unmarshal. 
@@ -72,7 +71,7 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 80f8c6430..e94aceb92 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,11 +15,12 @@ package netfilter import ( + "encoding/binary" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,6 +36,11 @@ const ErrorTargetName = "ERROR" // change the destination port and/or IP for packets. const RedirectTargetName = "REDIRECT" +// SNATTargetName is used to mark targets as SNAT targets. SNAT targets should +// be reached for only NAT table. These targets will change the source port +// and/or IP for packets. +const SNATTargetName = "SNAT" + func init() { // Standard targets include ACCEPT, DROP, RETURN, and JUMP. registerTargetMaker(&standardTargetMaker{ @@ -59,6 +65,13 @@ func init() { registerTargetMaker(&nfNATTargetMaker{ NetworkProtocol: header.IPv6ProtocolNumber, }) + + registerTargetMaker(&snatTargetMakerV4{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&snatTargetMakerV6{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) } // The stack package provides some basic, useful targets for us. 
The following @@ -131,6 +144,17 @@ func (rt *redirectTarget) id() targetID { } } +type snatTarget struct { + stack.SNATTarget +} + +func (st *snatTarget) id() targetID { + return targetID{ + name: SNATTargetName, + networkProtocol: st.NetworkProtocol, + } +} + type standardTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } @@ -166,8 +190,7 @@ func (*standardTargetMaker) marshal(target target) []byte { Verdict: verdict, } - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -176,8 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } var standardTarget linux.XTStandardTarget - buf = buf[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &standardTarget) + standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()]) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. @@ -222,8 +244,7 @@ func (*errorTargetMaker) marshal(target target) []byte { copy(xt.Name[:], errorName) copy(xt.Target.Name[:], ErrorTargetName) - ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -233,7 +254,7 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &errTgt) + errTgt.UnmarshalUnsafe(buf) // Error targets are used in 2 cases: // * An actual error case. 
These rules have an error named @@ -276,12 +297,11 @@ func (*redirectTargetMaker) marshal(target target) []byte { } copy(xt.Target.Name[:], RedirectTargetName) - ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) xt.NfRange.RangeSize = 1 xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED xt.NfRange.RangeIPV4.MinPort = htons(rt.Port) xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -297,7 +317,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &rt) + rt.UnmarshalUnsafe(buf) // Copy linux.XTRedirectTarget to stack.RedirectTarget. target := redirectTarget{RedirectTarget: stack.RedirectTarget{ @@ -336,12 +356,13 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return &target, nil } +// +marshal type nfNATTarget struct { Target linux.XTEntryTarget Range linux.NFNATRange } -const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange +const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange type nfNATTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber @@ -358,7 +379,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte { rt := target.(*redirectTarget) nt := nfNATTarget{ Target: linux.XTEntryTarget{ - TargetSize: nfNATMarhsalledSize, + TargetSize: nfNATMarshalledSize, }, Range: linux.NFNATRange{ Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, @@ -371,12 +392,11 @@ func (*nfNATTargetMaker) marshal(target target) []byte { nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto - ret := make([]byte, 0, nfNATMarhsalledSize) - return binary.Marshal(ret, hostarch.ByteOrder, nt) + return marshal.Marshal(&nt) } 
func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { - if size := nfNATMarhsalledSize; len(buf) < size { + if size := nfNATMarshalledSize; len(buf) < size { nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) return nil, syserr.ErrInvalidArgument } @@ -387,8 +407,8 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize] - binary.Unmarshal(buf, hostarch.ByteOrder, &natRange) + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] + natRange.UnmarshalUnsafe(buf) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -418,6 +438,159 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return &target, nil } +type snatTargetMakerV4 struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (st *snatTargetMakerV4) id() targetID { + return targetID{ + name: SNATTargetName, + networkProtocol: st.NetworkProtocol, + } +} + +func (*snatTargetMakerV4) marshal(target target) []byte { + st := target.(*snatTarget) + // This is a snat target named snat. 
+ xt := linux.XTSNATTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTSNATTarget, + }, + } + copy(xt.Target.Name[:], SNATTargetName) + + xt.NfRange.RangeSize = 1 + xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED + xt.NfRange.RangeIPV4.MinPort = htons(st.Port) + xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort + copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr) + copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr) + return marshal.Marshal(&xt) +} + +func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { + if len(buf) < linux.SizeOfXTSNATTarget { + nflog("snatTargetMakerV4: buf has insufficient size for snat target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("snatTargetMakerV4: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var st linux.XTSNATTarget + buf = buf[:linux.SizeOfXTSNATTarget] + st.UnmarshalUnsafe(buf) + + // Copy linux.XTSNATTarget to stack.SNATTarget. + target := snatTarget{SNATTarget: stack.SNATTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }} + + // RangeSize should be 1. + nfRange := st.NfRange + if nfRange.RangeSize != 1 { + nflog("snatTargetMakerV4: bad rangesize %d", nfRange.RangeSize) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/5772): If the rule doesn't specify the source port, + // choose one automatically. + if nfRange.RangeIPV4.MinPort == 0 { + nflog("snatTargetMakerV4: snat target needs to specify a non-zero port") + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. 
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument + } + if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP { + nflog("snatTargetMakerV4: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument + } + + target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.Port = ntohs(nfRange.RangeIPV4.MinPort) + + return &target, nil +} + +type snatTargetMakerV6 struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (st *snatTargetMakerV6) id() targetID { + return targetID{ + name: SNATTargetName, + networkProtocol: st.NetworkProtocol, + revision: 1, + } +} + +func (*snatTargetMakerV6) marshal(target target) []byte { + st := target.(*snatTarget) + nt := nfNATTarget{ + Target: linux.XTEntryTarget{ + TargetSize: nfNATMarshalledSize, + }, + Range: linux.NFNATRange{ + Flags: linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED, + }, + } + copy(nt.Target.Name[:], SNATTargetName) + copy(nt.Range.MinAddr[:], st.Addr) + copy(nt.Range.MaxAddr[:], st.Addr) + nt.Range.MinProto = htons(st.Port) + nt.Range.MaxProto = nt.Range.MinProto + + return marshal.Marshal(&nt) +} + +func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { + if size := nfNATMarshalledSize; len(buf) < size { + nflog("snatTargetMakerV6: buf has insufficient size (%d) for SNAT V6 target (%d)", len(buf), size) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("snatTargetMakerV6: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var natRange linux.NFNATRange + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] + natRange.UnmarshalUnsafe(buf) + + // TODO(gvisor.dev/issue/5689): Support port or address ranges. 
+ if natRange.MinAddr != natRange.MaxAddr { + nflog("snatTargetMakerV6: MinAddr and MaxAddr are different") + return nil, syserr.ErrInvalidArgument + } + if natRange.MinProto != natRange.MaxProto { + nflog("snatTargetMakerV6: MinProto and MaxProto are different") + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/5698): Support other NF_NAT_RANGE flags. + if natRange.Flags != linux.NF_NAT_RANGE_MAP_IPS|linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("snatTargetMakerV6: invalid range flags %d", natRange.Flags) + return nil, syserr.ErrInvalidArgument + } + + target := snatTarget{ + SNATTarget: stack.SNATTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Addr: tcpip.Address(natRange.MinAddr[:]), + Port: ntohs(natRange.MinProto), + }, + } + + return &target, nil +} + // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { @@ -453,8 +626,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget - buf := optVal[:linux.SizeOfXTEntryTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &target) + target.UnmarshalUnsafe(optVal[:target.SizeBytes()]) return unmarshalTarget(target, filter, optVal) } @@ -480,7 +652,7 @@ func (jt *JumpTarget) id() targetID { } // Action implements stack.Target.Action. 
-func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 69557f515..95bb9826e 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,8 +46,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte { DestinationPortStart: matcher.destinationPortStart, DestinationPortEnd: matcher.destinationPortEnd, } - buf := make([]byte, 0, linux.SizeOfXTTCP) - return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, hostarch.ByteOrder, xttcp)) + return marshalEntryMatch(matcherNameTCP, marshal.Marshal(&xttcp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +58,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. 
var matchData linux.XTTCP - binary.Unmarshal(buf[:linux.SizeOfXTTCP], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 6a60e6bd6..fb8be27e6 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,8 +46,7 @@ func (udpMarshaler) marshal(mr matcher) []byte { DestinationPortStart: matcher.destinationPortStart, DestinationPortEnd: matcher.destinationPortEnd, } - buf := make([]byte, 0, linux.SizeOfXTUDP) - return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, hostarch.ByteOrder, xtudp)) + return marshalEntryMatch(matcherNameUDP, marshal.Marshal(&xtudp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +58,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. 
var matchData linux.XTUDP - binary.Unmarshal(buf[:linux.SizeOfXTUDP], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 171b95c63..64cd263da 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -14,7 +14,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/context", "//pkg/hostarch", "//pkg/marshal", @@ -50,5 +50,7 @@ go_test( deps = [ ":netlink", "//pkg/abi/linux", + "//pkg/marshal", + "//pkg/marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index ab0e68af7..80385bfdc 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -19,15 +19,17 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" ) // alignPad returns the length of padding required for alignment. // // Preconditions: align is a power of two. func alignPad(length int, align uint) int { - return binary.AlignUp(length, align) - length + return bits.AlignUp(length, align) - length } // Message contains a complete serialized netlink message. @@ -42,7 +44,7 @@ type Message struct { func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ hdr: hdr, - buf: binary.Marshal(nil, hostarch.ByteOrder, hdr), + buf: marshal.Marshal(&hdr), } } @@ -58,7 +60,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { return } var hdr linux.NetlinkMessageHeader - binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBytes) // Msg portion. 
totalMsgLen := int(hdr.Length) @@ -92,7 +94,7 @@ func (m *Message) Header() linux.NetlinkMessageHeader { // GetData unmarshals the payload message header from this netlink message, and // returns the attributes portion. -func (m *Message) GetData(msg interface{}) (AttrsView, bool) { +func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) { b := BytesView(m.buf) _, ok := b.Extract(linux.NetlinkMessageHeaderSize) @@ -100,12 +102,12 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) { return nil, false } - size := int(binary.Size(msg)) + size := msg.SizeBytes() msgBytes, ok := b.Extract(size) if !ok { return nil, false } - binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg) + msg.UnmarshalUnsafe(msgBytes) numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) // Linux permits the last message not being aligned, just consume all of it. @@ -131,7 +133,7 @@ func (m *Message) Finalize() []byte { // Align the message. Note that the message length in the header (set // above) is the useful length of the message, not the total aligned // length. See net/netlink/af_netlink.c:__nlmsg_put. - aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) + aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) m.putZeros(aligned - len(m.buf)) return m.buf } @@ -145,45 +147,45 @@ func (m *Message) putZeros(n int) { } // Put serializes v into the message. -func (m *Message) Put(v interface{}) { - m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v) +func (m *Message) Put(v marshal.Marshallable) { + m.buf = append(m.buf, marshal.Marshal(v)...) } // PutAttr adds v to the message as a netlink attribute. // // Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize + -// binary.Size(v) fits in math.MaxUint16 bytes. -func (m *Message) PutAttr(atype uint16, v interface{}) { - l := linux.NetlinkAttrHeaderSize + int(binary.Size(v)) +// v.SizeBytes()) fits in math.MaxUint16 bytes. 
+func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) { + l := linux.NetlinkAttrHeaderSize + v.SizeBytes() if l > math.MaxUint16 { panic(fmt.Sprintf("attribute too large: %d", l)) } - m.Put(linux.NetlinkAttrHeader{ + m.Put(&linux.NetlinkAttrHeader{ Type: atype, Length: uint16(l), }) m.Put(v) // Align the attribute. - aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) + aligned := bits.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } // PutAttrString adds s to the message as a netlink attribute. func (m *Message) PutAttrString(atype uint16, s string) { l := linux.NetlinkAttrHeaderSize + len(s) + 1 - m.Put(linux.NetlinkAttrHeader{ + m.Put(&linux.NetlinkAttrHeader{ Type: atype, Length: uint16(l), }) // String + NUL-termination. - m.Put([]byte(s)) + m.Put(primitive.AsByteSlice([]byte(s))) m.putZeros(1) // Align the attribute. - aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) + aligned := bits.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } @@ -251,7 +253,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest if !ok { return } - binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBytes) value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) if !ok { diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index ef13d9386..968968469 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -20,13 +20,31 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) type dummyNetlinkMsg struct { + marshal.StubMarshallable Foo uint16 } +func (*dummyNetlinkMsg) SizeBytes() int { + return 2 +} + +func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { + p := primitive.Uint16(m.Foo) + p.MarshalUnsafe(dst) +} + +func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { + var p 
primitive.Uint16 + p.UnmarshalUnsafe(src) + m.Foo = uint16(p) +} + func TestParseMessage(t *testing.T) { tests := []struct { desc string diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 744fc74f4..c6c04b4e3 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -11,6 +11,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 5a2255db3..86f6419dc 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -167,7 +168,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { Type: linux.RTM_NEWLINK, }) - m.Put(linux.InterfaceInfoMessage{ + m.Put(&linux.InterfaceInfoMessage{ Family: linux.AF_UNSPEC, Type: i.DeviceType, Index: idx, @@ -175,7 +176,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { }) m.PutAttrString(linux.IFLA_IFNAME, i.Name) - m.PutAttr(linux.IFLA_MTU, i.MTU) + m.PutAttr(linux.IFLA_MTU, primitive.AllocateUint32(i.MTU)) mac := make([]byte, 6) brd := mac @@ -183,8 +184,8 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { mac = i.Addr brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) } - m.PutAttr(linux.IFLA_ADDRESS, mac) - m.PutAttr(linux.IFLA_BROADCAST, brd) + m.PutAttr(linux.IFLA_ADDRESS, primitive.AsByteSlice(mac)) + m.PutAttr(linux.IFLA_BROADCAST, primitive.AsByteSlice(brd)) // TODO(gvisor.dev/issue/578): There are many more attributes. 
} @@ -216,14 +217,15 @@ func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netl Type: linux.RTM_NEWADDR, }) - m.Put(linux.InterfaceAddrMessage{ + m.Put(&linux.InterfaceAddrMessage{ Family: a.Family, PrefixLen: a.PrefixLen, Index: uint32(id), }) - m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr)) - m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr)) + addr := primitive.ByteSlice([]byte(a.Addr)) + m.PutAttr(linux.IFA_LOCAL, &addr) + m.PutAttr(linux.IFA_ADDRESS, &addr) // TODO(gvisor.dev/issue/578): There are many more attributes. } @@ -366,7 +368,7 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net Type: linux.RTM_NEWROUTE, }) - m.Put(linux.RouteMessage{ + m.Put(&linux.RouteMessage{ Family: rt.Family, DstLen: rt.DstLen, SrcLen: rt.SrcLen, @@ -382,18 +384,18 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net Flags: rt.Flags, }) - m.PutAttr(254, []byte{123}) + m.PutAttr(254, primitive.AsByteSlice([]byte{123})) if rt.DstLen > 0 { - m.PutAttr(linux.RTA_DST, rt.DstAddr) + m.PutAttr(linux.RTA_DST, primitive.AsByteSlice(rt.DstAddr)) } if rt.SrcLen > 0 { - m.PutAttr(linux.RTA_SRC, rt.SrcAddr) + m.PutAttr(linux.RTA_SRC, primitive.AsByteSlice(rt.SrcAddr)) } if rt.OutputInterface != 0 { - m.PutAttr(linux.RTA_OIF, rt.OutputInterface) + m.PutAttr(linux.RTA_OIF, primitive.AllocateInt32(rt.OutputInterface)) } if len(rt.GatewayAddr) > 0 { - m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr) + m.PutAttr(linux.RTA_GATEWAY, primitive.AsByteSlice(rt.GatewayAddr)) } // TODO(gvisor.dev/issue/578): There are many more attributes. @@ -503,7 +505,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms hdr := msg.Header() // All messages start with a 1 byte protocol family. - var family uint8 + var family primitive.Uint8 if _, ok := msg.GetData(&family); !ok { // Linux ignores messages missing the protocol family. See // net/core/rtnetlink.c:rtnetlink_rcv_msg. 
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 30c297149..d75a2879f 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -20,7 +20,6 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -223,7 +222,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa) + sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument @@ -338,16 +337,14 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - sendBufferSizeP := primitive.Int32(s.sendBufferSize) - return &sendBufferSizeP, nil + return primitive.AllocateInt32(int32(s.sendBufferSize)), nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - recvBufferSizeP := primitive.Int32(math.MaxInt32) - return &recvBufferSizeP, nil + return primitive.AllocateInt32(math.MaxInt32), nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -484,7 +481,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * Family: linux.AF_NETLINK, PortID: uint32(s.portID), } - return sa, uint32(binary.Size(sa)), nil + return sa, uint32(sa.SizeBytes()), nil } // GetPeerName implements socket.Socket.GetPeerName. @@ -495,7 +492,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * // must be the kernel. PortID: 0, } - return sa, uint32(binary.Size(sa)), nil + return sa, uint32(sa.SizeBytes()), nil } // RecvMsg implements socket.Socket.RecvMsg. 
@@ -504,7 +501,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Family: linux.AF_NETLINK, PortID: 0, } - fromLen := uint32(binary.Size(from)) + fromLen := uint32(from.SizeBytes()) trunc := flags&linux.MSG_TRUNC != 0 @@ -640,7 +637,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys }) // Add the dump_done_errno payload. - m.Put(int64(0)) + m.Put(primitive.AllocateInt64(0)) _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { @@ -658,7 +655,7 @@ func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ + m.Put(&linux.NetlinkErrorMessage{ Error: int32(-err.ToLinux().Number()), Header: hdr, }) @@ -668,7 +665,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ + m.Put(&linux.NetlinkErrorMessage{ Error: 0, Header: hdr, }) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 0b39a5b67..9561b7c25 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -19,7 +19,6 @@ go_library( ], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index ed6572bab..60ef33360 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "encoding/binary" "fmt" "io" "io/ioutil" @@ -35,7 +36,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" @@ -199,6 
+199,13 @@ var Metrics = tcpip.Stats{ OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Number of record route options found in received IP packets."), OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Number of router alert options found in received IP packets."), OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Number of unknown options found in received IP packets."), + Forwarding: tcpip.IPForwardingStats{ + Unrouteable: mustCreateMetric("/netstack/ip/forwarding/unrouteable", "Number of IP packets received which couldn't be routed and thus were not forwarded."), + ExhaustedTTL: mustCreateMetric("/netstack/ip/forwarding/exhausted_ttl", "Number of IP packets received which could not be forwarded due to an exhausted TTL."), + LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), + LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), + Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), + }, }, ARP: tcpip.ARPStats{ PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."), @@ -242,6 +249,7 @@ var Metrics = tcpip.Stats{ FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."), Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."), ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), + FailedPortReservations: mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to 
reserve a port."), }, UDP: tcpip.UDPStats{ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."), @@ -374,9 +382,9 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue }), nil } -var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{})) -var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{})) -var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{})) +var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes() +var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes() +var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes() // bytesToIPAddress converts an IPv4 or IPv6 address from the user to the // netstack representation taking any addresses into account. @@ -612,7 +620,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - binary.Unmarshal(sockaddr[:sockAddrLinkSize], hostarch.ByteOrder, &a) + a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) if a.Protocol != uint16(s.protocol) { return syserr.ErrInvalidArgument @@ -885,10 +893,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } + size := ep.SocketOptions().GetReceiveBufferSize() if size > math.MaxInt32 { size = math.MaxInt32 @@ -1314,7 +1319,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return &v, nil case linux.IP6T_ORIGINAL_DST: - if outLen < int(binary.Size(linux.SockAddrInet6{})) { + if outLen < sockAddrInet6Size { return nil, syserr.ErrInvalidArgument } @@ -1511,7 +1516,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return &v, nil case linux.SO_ORIGINAL_DST: - if outLen < 
int(binary.Size(linux.SockAddrInet{})) { + if outLen < sockAddrInetSize { return nil, syserr.ErrInvalidArgument } @@ -1661,7 +1666,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := hostarch.ByteOrder.Uint32(optVal) - ep.SocketOptions().SetSendBufferSize(int64(v), true) + ep.SocketOptions().SetSendBufferSize(int64(v), true /* notify */) return nil case linux.SO_RCVBUF: @@ -1670,7 +1675,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := hostarch.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) + ep.SocketOptions().SetReceiveBufferSize(int64(v), true /* notify */) + return nil case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -1743,7 +1749,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1756,7 +1762,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1792,7 +1798,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - binary.Unmarshal(optVal[:linux.SizeOfLinger], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfLinger]) + + if v != (linux.Linger{}) { + socket.SetSockOptEmitUnimplementedEvent(t, name) + } ep.SocketOptions().SetLinger(tcpip.LingerOption{ Enabled: v.OnOff != 0, @@ -2091,9 +2101,9 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } var ( - 
inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{})) - inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{})) - inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{})) + inetMulticastRequestSize = (*linux.InetMulticastRequest)(nil).SizeBytes() + inetMulticastRequestWithNICSize = (*linux.InetMulticastRequestWithNIC)(nil).SizeBytes() + inet6MulticastRequestSize = (*linux.Inet6MulticastRequest)(nil).SizeBytes() ) // copyInMulticastRequest copies in a variable-size multicast request. The @@ -2118,12 +2128,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], hostarch.ByteOrder, &req) + req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize]) return req, nil } var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestSize], hostarch.ByteOrder, &req.InetMulticastRequest) + req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize]) return req, nil } @@ -2133,7 +2143,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - binary.Unmarshal(optVal[:inet6MulticastRequestSize], hostarch.ByteOrder, &req) + req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize]) return req, nil } @@ -3102,8 +3112,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe continue } // Populate ifr.ifr_netmask (type sockaddr). - hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET)) - hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0) + hostarch.ByteOrder.PutUint16(ifr.Data[0:], uint16(linux.AF_INET)) + hostarch.ByteOrder.PutUint16(ifr.Data[2:], 0) var mask uint32 = 0xffffffff << (32 - addr.PrefixLen) // Netmask is expected to be returned as a big endian // value. 
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 4c3d48096..9e56487a6 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -24,7 +24,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -572,19 +571,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - binary.Unmarshal(data[:unix.SizeofSockaddrInet4], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - binary.Unmarshal(data[:unix.SizeofSockaddrInet6], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - binary.Unmarshal(data[:unix.SizeofSockaddrUnix], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -716,7 +715,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInetSize], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrInetSize]) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -729,7 +728,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInet6Size], hostarch.ByteOrder, &a) + 
a.UnmarshalUnsafe(addr[:sockAddrInet6Size]) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -745,7 +744,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrLinkSize], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 159b8f90f..33f9aeb06 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -130,7 +130,8 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv } ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } @@ -175,8 +176,9 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */) + ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) return ep } @@ -299,8 +301,9 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } - ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits) + ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) 
ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) + ne.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize} readQueue.InitRefs() @@ -343,11 +346,11 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn return nil default: - // Busy; return ECONNREFUSED per spec. + // Busy; return EAGAIN per spec. ne.Close(ctx) e.Unlock() ce.Unlock() - return syserr.ErrConnectionRefused + return syserr.ErrTryAgain } } @@ -366,6 +369,7 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint // to reflect this endpoint's send buffer size. if bufSz := e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()); bufSz != e.ops.GetSendBufferSize() { e.ops.SetSendBufferSize(bufSz, false /* notify */) + e.ops.SetReceiveBufferSize(bufSz, false /* notify */) } } diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go index 590b0bd01..b20334d4f 100644 --- a/pkg/sentry/socket/unix/transport/connectioned_state.go +++ b/pkg/sentry/socket/unix/transport/connectioned_state.go @@ -54,5 +54,5 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd // afterLoad is invoked by stateify. 
func (e *connectionedEndpoint) afterLoad() { - e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) } diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index d0df28b59..61338728a 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -45,7 +45,8 @@ func NewConnectionless(ctx context.Context) Endpoint { q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go index 2ef337ec8..1bb71baf7 100644 --- a/pkg/sentry/socket/unix/transport/connectionless_state.go +++ b/pkg/sentry/socket/unix/transport/connectionless_state.go @@ -16,5 +16,5 @@ package transport // afterLoad is invoked by stateify. 
func (e *connectionlessEndpoint) afterLoad() { - e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0c5f5ab42..837ab4fde 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -868,11 +868,7 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } + log.Warningf("Unsupported socket option: %d", opt) return nil } @@ -905,19 +901,6 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } return int(v), nil - case tcpip.ReceiveBufferSizeOption: - e.Lock() - if e.receiver == nil { - e.Unlock() - return -1, &tcpip.ErrNotConnected{} - } - v := e.receiver.RecvMaxQueueSize() - e.Unlock() - if v < 0 { - return -1, &tcpip.ErrQueueSizeNotSupported{} - } - return int(v), nil - default: log.Warningf("Unsupported socket option: %d", opt) return -1, &tcpip.ErrUnknownProtocolOption{} @@ -1029,3 +1012,15 @@ func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption { Max: maxBufferSize, } } + +// getReceiveBufferLimits implements tcpip.GetReceiveBufferLimits. +// +// We define min, max and default values for unix socket implementation. Unix +// sockets do not use receive buffer. 
+func getReceiveBufferLimits(tcpip.StackHandler) tcpip.ReceiveBufferSizeOption { + return tcpip.ReceiveBufferSizeOption{ + Min: minimumBufferSize, + Default: defaultBufferSize, + Max: maxBufferSize, + } +} diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 2ebd77f82..1fbbd133c 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -25,7 +25,6 @@ go_library( ":strace_go_proto", "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/bits", "//pkg/eventchannel", "//pkg/hostarch", diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index 71b92eaee..d66befe81 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -371,6 +371,7 @@ var linuxAMD64 = SyscallMap{ 433: makeSyscallInfo("fspick", FD, Path, Hex), 434: makeSyscallInfo("pidfd_open", Hex, Hex), 435: makeSyscallInfo("clone3", Hex, Hex), + 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet), } func init() { diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index bd7361a52..1a2d7d75f 100644 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go @@ -312,6 +312,7 @@ var linuxARM64 = SyscallMap{ 433: makeSyscallInfo("fspick", FD, Path, Hex), 434: makeSyscallInfo("pidfd_open", Hex, Hex), 435: makeSyscallInfo("clone3", Hex, Hex), + 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet), } func init() { diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index e5b7f9b96..f4aab25b0 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -20,14 +20,13 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" 
"gvisor.dev/gvisor/pkg/sentry/socket/netlink" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - - "gvisor.dev/gvisor/pkg/hostarch" ) // SocketFamily are the possible socket(2) families. @@ -162,6 +161,15 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } +func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights { + count := len(src) / linux.SizeOfControlMessageRight + cmr := make(linux.ControlMessageRights, count) + for i, _ := range cmr { + cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:])) + } + return cmr +} + func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) string { if length > maxBytes { return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length) @@ -181,7 +189,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) + h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) var skipData bool level := "SOL_SOCKET" @@ -221,18 +229,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) continue } switch h.Type { case linux.SCM_RIGHTS: - rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) - - numRights := rightsSize / linux.SizeOfControlMessageRight - fds := make(linux.ControlMessageRights, numRights) - binary.Unmarshal(buf[i:i+rightsSize], hostarch.ByteOrder, &fds) - + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) + fds := unmarshalControlMessageRights(buf[i : i+rightsSize]) rights := make([]string, 0, len(fds)) for _, fd := range fds { rights = append(rights, fmt.Sprint(fd)) @@ -258,7 +262,7 @@ func cmsghdr(t 
*kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) + creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", @@ -282,7 +286,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var tv linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &tv) + tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", @@ -296,7 +300,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) default: panic("unreachable") } - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go index e115683f8..3b4d79889 100644 --- a/pkg/sentry/syscalls/epoll.go +++ b/pkg/sentry/syscalls/epoll.go @@ -119,7 +119,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error { } // WaitEpoll implements the epoll_wait(2) linux syscall. -func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) { +func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux.EpollEvent, error) { // Get epoll from the file descriptor. epollfile := t.GetFile(fd) if epollfile == nil { @@ -136,7 +136,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve // Try to read events and return right away if we got them or if the // caller requested a non-blocking "wait". 
r := e.ReadEvents(max) - if len(r) != 0 || timeout == 0 { + if len(r) != 0 || timeoutInNanos == 0 { return r, nil } @@ -144,8 +144,8 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve // and register with the epoll object for readability events. var haveDeadline bool var deadline ktime.Time - if timeout > 0 { - timeoutDur := time.Duration(timeout) * time.Millisecond + if timeoutInNanos > 0 { + timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) haveDeadline = true } diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index efec93f73..6eabfd219 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -29,10 +29,17 @@ import ( ) var ( - partialResultMetric = metric.MustCreateNewUint64Metric("/syscalls/partial_result", true /* sync */, "Whether or not a partial result has occurred for this sandbox.") - partialResultOnce sync.Once + partialResultOnce sync.Once ) +// incrementPartialResultMetric increments PartialResultMetric by calling +// Increment(). This is added as the func Do() which is called below requires +// us to pass a function which does not take any arguments, whereas Increment() +// takes a variadic number of arguments. +func incrementPartialResultMetric() { + metric.WeirdnessMetric.Increment("partial_result") +} + // HandleIOErrorVFS2 handles special error cases for partial results. For some // errors, we may consume the error and return only the partial read/write. 
// @@ -48,7 +55,7 @@ func HandleIOErrorVFS2(ctx context.Context, partialResult bool, ioerr, intr erro root := vfs.RootFromContext(ctx) name, _ := fs.PathnameWithDeleted(ctx, root, f.VirtualDentry()) log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, ioerr, ioerr, op, name) - partialResultOnce.Do(partialResultMetric.Increment) + partialResultOnce.Do(incrementPartialResultMetric) } return nil } @@ -66,7 +73,7 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o // An unknown error is encountered with a partial read/write. name, _ := f.Dirent.FullName(nil /* ignore chroot */) log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, ioerr, ioerr, op, name, f.FileOperations) - partialResultOnce.Do(partialResultMetric.Increment) + partialResultOnce.Do(incrementPartialResultMetric) } return nil } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 2d2212605..090c5ffcb 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -404,6 +404,7 @@ var AMD64 = &kernel.SyscallTable{ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{ 0xffffffffff600000: 96, // vsyscall gettimeofday(2) @@ -722,6 +723,7 @@ var ARM64 = &kernel.SyscallTable{ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{}, Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, 
error) { diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index 7f460d30b..69cbc98d0 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" @@ -104,14 +105,8 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } -// EpollWait implements the epoll_wait(2) linux syscall. -func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - epfd := args[0].Int() - eventsAddr := args[1].Pointer() - maxEvents := int(args[2].Int()) - timeout := int(args[3].Int()) - - r, err := syscalls.WaitEpoll(t, epfd, maxEvents, timeout) +func waitEpoll(t *kernel.Task, fd int32, eventsAddr hostarch.Addr, max int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { + r, err := syscalls.WaitEpoll(t, fd, max, timeoutInNanos) if err != nil { return 0, nil, syserror.ConvertIntr(err, syserror.EINTR) } @@ -123,6 +118,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } return uintptr(len(r)), nil, nil + +} + +// EpollWait implements the epoll_wait(2) linux syscall. +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + // Convert milliseconds to nanoseconds. + timeoutInNanos := int64(args[3].Int()) * 1000000 + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) } // EpollPwait implements the epoll_pwait(2) linux syscall. @@ -144,4 +150,38 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } +// EpollPwait2 implements the epoll_pwait(2) linux syscall. 
+func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutPtr := args[3].Pointer() + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + haveTimeout := timeoutPtr != 0 + + var timeoutInNanos int64 = -1 + if haveTimeout { + timeout, err := copyTimespecIn(t, timeoutPtr) + if err != nil { + return 0, nil, err + } + timeoutInNanos = timeout.ToNsec() + + } + + if maskAddr != 0 { + mask, err := CopyInSigSet(t, maskAddr, maskSize) + if err != nil { + return 0, nil, err + } + + oldmask := t.SignalMask() + t.SetSignalMask(mask) + t.SetSavedSignalMask(oldmask) + } + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) +} + // LINT.ThenChange(vfs2/epoll.go) diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9bdf6d3d8..e07917613 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -35,12 +35,6 @@ import ( // LINT.IfChange -// minListenBacklog is the minimum reasonable backlog for listening sockets. -const minListenBacklog = 8 - -// maxListenBacklog is the maximum allowed backlog for listening sockets. -const maxListenBacklog = 1024 - // maxAddrLen is the maximum socket address length we're willing to accept. const maxAddrLen = 200 @@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8 // buffers upto INT_MAX. const maxControlLen = 10 * 1024 * 1024 +// maxListenBacklog is the maximum limit of listen backlog supported. +const maxListenBacklog = 1024 + // nameLenOffset is the offset from the start of the MessageHeader64 struct to // the NameLen field. const nameLenOffset = 8 @@ -367,7 +364,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Listen implements the linux syscall listen(2). 
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() - backlog := args[1].Int() + backlog := args[1].Uint() // Get socket from the file descriptor. file := t.GetFile(fd) @@ -382,14 +379,23 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ENOTSOCK } - // Per Linux, the backlog is silently capped to reasonable values. - if backlog <= 0 { - backlog = minListenBacklog - } if backlog > maxListenBacklog { + // Linux treats incoming backlog as uint with a limit defined by + // sysctl_somaxconn. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666 backlog = maxListenBacklog } + // Accept one more than the configured listen backlog to keep in parity with + // Linux. Ref, because of missing equality check here: + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937 + // + // In case of unix domain sockets, the following check + // https://github.com/torvalds/linux/blob/7d6beb71da3/net/unix/af_unix.c#L1293 + // will allow 1 connect through since it checks for a receive queue len > + // backlog and not >=. 
+ backlog++ + return 0, nil, s.Listen(t, int(backlog)).ToError() } @@ -457,8 +463,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(v.SizeBytes()) - if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index b980aa43e..047d955b6 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -118,13 +119,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } -// EpollWait implements Linux syscall epoll_wait(2). 
-func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - epfd := args[0].Int() - eventsAddr := args[1].Pointer() - maxEvents := int(args[2].Int()) - timeout := int(args[3].Int()) - +func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS { return 0, nil, syserror.EINVAL @@ -158,7 +153,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } return 0, nil, err } - if timeout == 0 { + if timeoutInNanos == 0 { return 0, nil, nil } // In the first iteration of this loop, register with the epoll @@ -173,8 +168,8 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys defer epfile.EventUnregister(&w) } else { // Set up the timer if a timeout was specified. - if timeout > 0 && !haveDeadline { - timeoutDur := time.Duration(timeout) * time.Millisecond + if timeoutInNanos > 0 && !haveDeadline { + timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) haveDeadline = true } @@ -186,6 +181,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } } } + +} + +// EpollWait implements Linux syscall epoll_wait(2). +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutInNanos := int64(args[3].Int()) * 1000000 + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) } // EpollPwait implements Linux syscall epoll_pwait(2). 
@@ -199,3 +205,29 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } + +// EpollPwait2 implements Linux syscall epoll_pwait(2). +func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutPtr := args[3].Pointer() + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + haveTimeout := timeoutPtr != 0 + + var timeoutInNanos int64 = -1 + if haveTimeout { + var timeout linux.Timespec + if _, err := timeout.CopyIn(t, timeoutPtr); err != nil { + return 0, nil, err + } + timeoutInNanos = timeout.ToNsec() + } + + if err := setTempSignalSet(t, maskAddr, maskSize); err != nil { + return 0, nil, err + } + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index a87a66146..69f69e3af 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -35,12 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" ) -// minListenBacklog is the minimum reasonable backlog for listening sockets. -const minListenBacklog = 8 - -// maxListenBacklog is the maximum allowed backlog for listening sockets. -const maxListenBacklog = 1024 - // maxAddrLen is the maximum socket address length we're willing to accept. const maxAddrLen = 200 @@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8 // buffers upto INT_MAX. const maxControlLen = 10 * 1024 * 1024 +// maxListenBacklog is the maximum limit of listen backlog supported. +const maxListenBacklog = 1024 + // nameLenOffset is the offset from the start of the MessageHeader64 struct to // the NameLen field. const nameLenOffset = 8 @@ -371,7 +368,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Listen implements the linux syscall listen(2). 
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() - backlog := args[1].Int() + backlog := args[1].Uint() // Get socket from the file descriptor. file := t.GetFileVFS2(fd) @@ -386,14 +383,23 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ENOTSOCK } - // Per Linux, the backlog is silently capped to reasonable values. - if backlog <= 0 { - backlog = minListenBacklog - } if backlog > maxListenBacklog { + // Linux treats incoming backlog as uint with a limit defined by + // sysctl_somaxconn. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666 backlog = maxListenBacklog } + // Accept one more than the configured listen backlog to keep in parity with + // Linux. Ref, because of missing equality check here: + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937 + // + // In case of unix domain sockets, the following check + // https://github.com/torvalds/linux/blob/7d6beb71da3/net/unix/af_unix.c#L1293 + // will allow 1 connect through since it checks for a receive queue len > + // backlog and not >=. 
+ backlog++ + return 0, nil, s.Listen(t, int(backlog)).ToError() } @@ -461,8 +467,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(v.SizeBytes()) - if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index c50fd97eb..0fc81e694 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -159,6 +159,7 @@ func Override() { s.Table[327] = syscalls.Supported("preadv2", Preadv2) s.Table[328] = syscalls.Supported("pwritev2", Pwritev2) s.Table[332] = syscalls.Supported("statx", Statx) + s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2) s.Init() // Override ARM64. @@ -269,6 +270,7 @@ func Override() { s.Table[286] = syscalls.Supported("preadv2", Preadv2) s.Table[287] = syscalls.Supported("pwritev2", Pwritev2) s.Table[291] = syscalls.Supported("statx", Statx) + s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2) s.Init() } diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 87d8687ce..1f617ca8f 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -32,6 +32,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/gohacks", "//pkg/log", "//pkg/metric", "//pkg/sync", diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go index f9a93115d..39bf1e0de 100644 --- a/pkg/sentry/time/calibrated_clock.go +++ b/pkg/sentry/time/calibrated_clock.go @@ -25,11 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// fallbackMetric tracks failed updates. It is not sync, as it is not critical -// that all occurrences are captured and CalibratedClock may fallback many -// times. 
-var fallbackMetric = metric.MustCreateNewUint64Metric("/time/fallback", false /* sync */, "Incremented when a clock falls back to system calls due to a failed update") - // CalibratedClock implements a clock that tracks a reference clock. // // Users should call Update at regular intervals of around approxUpdateInterval @@ -102,7 +97,7 @@ func (c *CalibratedClock) resetLocked(str string, v ...interface{}) { c.Warningf(str+" Resetting clock; time may jump.", v...) c.ready = false c.ref.Reset() - fallbackMetric.Increment() + metric.WeirdnessMetric.Increment("time_fallback") } // updateParams updates the timekeeping parameters based on the passed diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index f612a71b2..176bcc242 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -524,7 +524,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St Start: fd.vd, }) stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, err } return fd.impl.Stat(ctx, opts) @@ -539,7 +539,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.SetStat(ctx, opts) @@ -555,7 +555,7 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vd, }) statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, err } return fd.impl.StatFS(ctx) @@ -701,7 +701,7 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string Start: fd.vd, }) names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return names, err } names, err := fd.impl.ListXattr(ctx, size) @@ -730,7 +730,7 @@ func (fd 
*FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions) Start: fd.vd, }) val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return val, err } return fd.impl.GetXattr(ctx, *opts) @@ -746,7 +746,7 @@ func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions) Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.SetXattr(ctx, *opts) @@ -762,7 +762,7 @@ func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error { Start: fd.vd, }) err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.RemoveXattr(ctx, name) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 1556b41a3..b87d9690a 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -252,6 +252,9 @@ type WritableDynamicBytesSource interface { // are backed by a bytes.Buffer that is regenerated when necessary, consistent // with Linux's fs/seq_file.c:single_open(). // +// If data additionally implements WritableDynamicBytesSource, writes are +// dispatched to the implementer. The source data is not automatically modified. +// // DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first // use. 
// diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 922f9e697..82fd382c2 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -826,6 +826,9 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi if mnt.Flags.NoExec { opts += ",noexec" } + if mopts := mnt.fs.Impl().MountOptions(); mopts != "" { + opts += "," + mopts + } // Format: // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order> @@ -970,17 +973,22 @@ func superBlockOpts(mountPath string, mnt *Mount) string { opts += "," + mopts } - // NOTE(b/147673608): If the mount is a cgroup, we also need to include - // the cgroup name in the options. For now we just read that from the - // path. + // NOTE(b/147673608): If the mount is a ramdisk-based fake cgroupfs, we also + // need to include the cgroup name in the options. For now we just read that + // from the path. Note that this is only possible when "cgroup" isn't + // registered as a valid filesystem type. // - // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we - // should get this value from the cgroup itself, and not rely on the - // path. + // TODO(gvisor.dev/issue/190): Once we removed fake cgroupfs support, we + // should remove this. + if cgroupfs := mnt.vfs.getFilesystemType("cgroup"); cgroupfs != nil && cgroupfs.opts.AllowUserMount { + // Real cgroupfs available. 
+ return opts + } if mnt.fs.FilesystemType().Name() == "cgroup" { splitPath := strings.Split(mountPath, "/") cgroupType := splitPath[len(splitPath)-1] opts += "," + cgroupType } + return opts } diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go index 39fbac987..47848c76b 100644 --- a/pkg/sentry/vfs/opath.go +++ b/pkg/sentry/vfs/opath.go @@ -121,7 +121,7 @@ func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, err Start: fd.vfsfd.vd, }) stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, err } @@ -134,6 +134,6 @@ func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vfsfd.vd, }) statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, err } diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index e4fd55012..97b898aba 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -44,13 +44,10 @@ type ResolvingPath struct { start *Dentry pit fspath.Iterator - flags uint16 - mustBeDir bool // final file must be a directory? - mustBeDirOrig bool - symlinks uint8 // number of symlinks traversed - symlinksOrig uint8 - curPart uint8 // index into parts - numOrigParts uint8 + flags uint16 + mustBeDir bool // final file must be a directory? 
+ symlinks uint8 // number of symlinks traversed + curPart uint8 // index into parts creds *auth.Credentials @@ -60,14 +57,9 @@ type ResolvingPath struct { nextStart *Dentry // ref held if not nil absSymlinkTarget fspath.Path - // ResolvingPath must track up to two relative paths: the "current" - // relative path, which is updated whenever a relative symlink is - // encountered, and the "original" relative path, which is updated from the - // current relative path by handleError() when resolution must change - // filesystems (due to reaching a mount boundary or absolute symlink) and - // overwrites the current relative path when Restart() is called. - parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator - origParts [1 + linux.MaxSymlinkTraversals]fspath.Iterator + // ResolvingPath tracks relative paths, which is updated whenever a relative + // symlink is encountered. + parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator } const ( @@ -120,6 +112,8 @@ var resolvingPathPool = sync.Pool{ }, } +// getResolvingPath gets a new ResolvingPath from the pool. Caller must call +// ResolvingPath.Release() when done. func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath { rp := resolvingPathPool.Get().(*ResolvingPath) rp.vfs = vfs @@ -132,17 +126,37 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat rp.flags |= rpflagsFollowFinalSymlink } rp.mustBeDir = pop.Path.Dir - rp.mustBeDirOrig = pop.Path.Dir rp.symlinks = 0 rp.curPart = 0 - rp.numOrigParts = 1 rp.creds = creds rp.parts[0] = pop.Path.Begin - rp.origParts[0] = pop.Path.Begin return rp } -func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) { +// Copy creates another ResolvingPath with the same state as the original. +// Copies are independent, using the copy does not change the original and +// vice-versa. +// +// Caller must call Resease() when done. 
+func (rp *ResolvingPath) Copy() *ResolvingPath { + copy := resolvingPathPool.Get().(*ResolvingPath) + *copy = *rp // All fields all shallow copiable. + + // Take extra reference for the copy if the original had them. + if copy.flags&rpflagsHaveStartRef != 0 { + copy.start.IncRef() + } + if copy.flags&rpflagsHaveMountRef != 0 { + copy.mount.IncRef() + } + // Reset error state. + copy.nextStart = nil + copy.nextMount = nil + return copy +} + +// Release decrements references if needed and returns the object to the pool. +func (rp *ResolvingPath) Release(ctx context.Context) { rp.root = VirtualDentry{} rp.decRefStartAndMount(ctx) rp.mount = nil @@ -240,25 +254,6 @@ func (rp *ResolvingPath) Advance() { } } -// Restart resets the stream of path components represented by rp to its state -// on entry to the current FilesystemImpl method. -func (rp *ResolvingPath) Restart(ctx context.Context) { - rp.pit = rp.origParts[rp.numOrigParts-1] - rp.mustBeDir = rp.mustBeDirOrig - rp.symlinks = rp.symlinksOrig - rp.curPart = rp.numOrigParts - 1 - copy(rp.parts[:], rp.origParts[:rp.numOrigParts]) - rp.releaseErrorState(ctx) -} - -func (rp *ResolvingPath) relpathCommit() { - rp.mustBeDirOrig = rp.mustBeDir - rp.symlinksOrig = rp.symlinks - rp.numOrigParts = rp.curPart + 1 - copy(rp.origParts[:rp.curPart], rp.parts[:]) - rp.origParts[rp.curPart] = rp.pit -} - // CheckRoot is called before resolving the parent of the Dentry d. If the // Dentry is contextually a VFS root, such that path resolution should treat // d's parent as itself, CheckRoot returns (true, nil). If the Dentry is the @@ -405,11 +400,10 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef rp.nextMount = nil rp.nextStart = nil - // Commit the previous FileystemImpl's progress through the relative - // path. 
(Don't consume the path component that caused us to traverse + // Don't consume the path component that caused us to traverse // through the mount root - i.e. the ".." - because we still need to - // resolve the mount point's parent in the new FilesystemImpl.) - rp.relpathCommit() + // resolve the mount point's parent in the new FilesystemImpl. + // // Restart path resolution on the new Mount. Don't bother calling // rp.releaseErrorState() since we already set nextMount and nextStart // to nil above. @@ -425,9 +419,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.nextMount = nil // Consume the path component that represented the mount point. rp.Advance() - // Commit the previous FilesystemImpl's progress through the relative - // path. - rp.relpathCommit() // Restart path resolution on the new Mount. rp.releaseErrorState(ctx) return true @@ -442,9 +433,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.Advance() // Prepend the symlink target to the relative path. rp.relpathPrepend(rp.absSymlinkTarget) - // Commit the previous FilesystemImpl's progress through the relative - // path, including the symlink target we just prepended. - rp.relpathCommit() // Restart path resolution on the new Mount. 
rp.releaseErrorState(ctx) return true diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 00f1847d8..87fdcf403 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -208,11 +208,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -230,11 +230,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede dentry: d, } rp.mount.IncRef() - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return vd, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return VirtualDentry{}, err } } @@ -252,7 +252,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } rp.mount.IncRef() name := rp.Component() - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return parentVD, name, nil } if checkInvariants { @@ -261,7 +261,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return VirtualDentry{}, "", err } } @@ -292,7 +292,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential for { err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldVD.DecRef(ctx) return nil } @@ -302,7 +302,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldVD.DecRef(ctx) return err } @@ -331,7 +331,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) 
return nil } if checkInvariants { @@ -340,7 +340,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -366,7 +366,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -375,7 +375,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -425,7 +425,6 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential rp := vfs.getResolvingPath(creds, pop) if opts.Flags&linux.O_DIRECTORY != 0 { rp.mustBeDir = true - rp.mustBeDirOrig = true } // Ignore O_PATH for verity, as verity performs extra operations on the fd for verification. // The underlying filesystem that verity wraps opens the fd with O_PATH. 
@@ -444,7 +443,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential for { fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) if opts.FileExec { if fd.Mount().Flags.NoExec { @@ -468,7 +467,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential return fd, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -480,11 +479,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden for { target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return target, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return "", err } } @@ -533,7 +532,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldParentVD.DecRef(ctx) return nil } @@ -543,7 +542,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldParentVD.DecRef(ctx) return err } @@ -569,7 +568,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.RmdirAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -578,7 +577,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -590,11 +589,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, 
rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -606,11 +605,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential for { stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return linux.Statx{}, err } } @@ -623,11 +622,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti for { statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return linux.Statfs{}, err } } @@ -652,7 +651,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent for { err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -661,7 +660,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -686,7 +685,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.UnlinkAt(ctx, rp) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -695,7 +694,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -707,7 +706,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C for { bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return 
bep, nil } if checkInvariants { @@ -716,7 +715,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -729,7 +728,7 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede for { names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return names, nil } if err == syserror.ENOTSUP { @@ -737,11 +736,11 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede // fs/xattr.c:vfs_listxattr() falls back to allowing the security // subsystem to return security extended attributes, which by // default don't exist. - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil, err } } @@ -754,11 +753,11 @@ func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Creden for { val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return val, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return "", err } } @@ -771,11 +770,11 @@ func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Creden for { err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -787,11 +786,11 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre for { err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } diff --git 
a/pkg/shim/utils/volumes.go b/pkg/shim/utils/volumes.go index 52a428179..cdcb88229 100644 --- a/pkg/shim/utils/volumes.go +++ b/pkg/shim/utils/volumes.go @@ -91,11 +91,9 @@ func isVolumePath(volume, path string) (bool, error) { // UpdateVolumeAnnotations add necessary OCI annotations for gvisor // volume optimization. func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { - var ( - uid string - err error - ) + var uid string if IsSandbox(s) { + var err error uid, err = podUID(s) if err != nil { // Skip if we can't get pod UID, because this doesn't work @@ -123,21 +121,18 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { } else { // This is a container. for i := range s.Mounts { - // An error is returned for sandbox if source - // annotation is not successfully applied, so - // it is guaranteed that the source annotation - // for sandbox has already been successfully - // applied at this point. + // An error is returned for sandbox if source annotation is not + // successfully applied, so it is guaranteed that the source annotation + // for sandbox has already been successfully applied at this point. // - // The volume name is unique inside a pod, so - // matching without podUID is fine here. + // The volume name is unique inside a pod, so matching without podUID + // is fine here. // - // TODO: Pass podUID down to shim for containers to do - // more accurate matching. + // TODO: Pass podUID down to shim for containers to do more accurate + // matching. if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes { - // gVisor requires the container mount type to match - // sandbox mount type. - s.Mounts[i].Type = v + // Container mount type must match the sandbox's mount type. 
+ changeMountType(&s.Mounts[i], v) updated = true } } @@ -153,3 +148,22 @@ func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { } return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666) } + +func changeMountType(m *specs.Mount, newType string) { + m.Type = newType + + // OCI spec allows bind mounts to be specified in options only. So if new type + // is not bind, remove bind/rbind from options. + // + // "For bind mounts (when options include either bind or rbind), the type is + // a dummy, often "none" (not listed in /proc/filesystems)." + if newType != "bind" { + newOpts := make([]string, 0, len(m.Options)) + for _, opt := range m.Options { + if opt != "rbind" && opt != "bind" { + newOpts = append(newOpts, opt) + } + } + m.Options = newOpts + } +} diff --git a/pkg/shim/utils/volumes_test.go b/pkg/shim/utils/volumes_test.go index 3e02c6151..b25c53c73 100644 --- a/pkg/shim/utils/volumes_test.go +++ b/pkg/shim/utils/volumes_test.go @@ -47,60 +47,60 @@ func TestUpdateVolumeAnnotations(t *testing.T) { } for _, test := range []struct { - desc string + name string spec *specs.Spec expected *specs.Spec expectErr bool expectUpdate bool }{ { - desc: "volume annotations for sandbox", + name: "volume annotations for sandbox", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." 
+ testVolumeName + ".options": "ro", + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", - "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, }, }, expectUpdate: true, }, { - desc: "volume annotations for sandbox with legacy log path", + name: "volume annotations for sandbox with legacy log path", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLegacyLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLegacyLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." 
+ testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", - "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, }, }, expectUpdate: true, }, { - desc: "tmpfs: volume annotations for container", + name: "tmpfs: volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -117,10 +117,10 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + containerTypeAnnotation: containerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ @@ -139,16 +139,16 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." 
+ testVolumeName + ".options": "ro", + containerTypeAnnotation: containerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expectUpdate: true, }, { - desc: "bind: volume annotations for container", + name: "bind: volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -159,10 +159,10 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + containerTypeAnnotation: containerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "container", + volumeKeyPrefix + testVolumeName + ".type": "bind", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ @@ -175,48 +175,48 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, Annotations: map[string]string{ - containerTypeAnnotation: containerTypeContainer, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + containerTypeAnnotation: containerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "container", + volumeKeyPrefix + testVolumeName + ".type": "bind", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expectUpdate: true, }, { - desc: "should not return error without pod log directory", + name: "should not return error without pod log directory", spec: &specs.Spec{ Annotations: map[string]string{ - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." 
+ testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + containerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, expected: &specs.Spec{ Annotations: map[string]string{ - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", - "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", - "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + containerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", }, }, }, { - desc: "should return error if volume path does not exist", + name: "should return error if volume path does not exist", spec: &specs.Spec{ Annotations: map[string]string{ - sandboxLogDirAnnotation: testLogDirPath, - containerTypeAnnotation: containerTypeSandbox, - "dev.gvisor.spec.mount.notexist.share": "pod", - "dev.gvisor.spec.mount.notexist.type": "tmpfs", - "dev.gvisor.spec.mount.notexist.options": "ro", + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + volumeKeyPrefix + "notexist.share": "pod", + volumeKeyPrefix + "notexist.type": "tmpfs", + volumeKeyPrefix + "notexist.options": "ro", }, }, expectErr: true, }, { - desc: "no volume annotations for sandbox", + name: "no volume annotations for sandbox", spec: &specs.Spec{ Annotations: map[string]string{ sandboxLogDirAnnotation: testLogDirPath, @@ -231,7 +231,7 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, { - desc: "no volume annotations for container", + name: "no volume annotations for container", spec: &specs.Spec{ Mounts: []specs.Mount{ { @@ -271,8 +271,46 @@ func TestUpdateVolumeAnnotations(t *testing.T) { }, }, }, + { + name: 
"bind options removed", + spec: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, + }, + Mounts: []specs.Mount{ + { + Destination: "/dst", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro", "bind", "rbind"}, + }, + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + volumeKeyPrefix + testVolumeName + ".share": "pod", + volumeKeyPrefix + testVolumeName + ".type": "tmpfs", + volumeKeyPrefix + testVolumeName + ".options": "ro", + volumeKeyPrefix + testVolumeName + ".source": testVolumePath, + }, + Mounts: []specs.Mount{ + { + Destination: "/dst", + Type: "tmpfs", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + }, + expectUpdate: true, + }, } { - t.Run(test.desc, func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { bundle, err := ioutil.TempDir(dir, "test-bundle") if err != nil { t.Fatalf("Create test bundle: %v", err) diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index d6c89c7e9..08d06e37b 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -7,7 +7,6 @@ go_library( srcs = ["statefile.go"], visibility = ["//:sandbox"], deps = [ - "//pkg/binary", "//pkg/compressio", "//pkg/state/wire", ], diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go index bdfb800fb..d27c8c8a8 100644 --- a/pkg/state/statefile/statefile.go +++ b/pkg/state/statefile/statefile.go @@ -48,6 +48,7 @@ import ( "compress/flate" "crypto/hmac" "crypto/sha256" + "encoding/binary" "encoding/json" "fmt" "hash" @@ -55,7 +56,6 @@ import ( "strings" "time" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/compressio" "gvisor.dev/gvisor/pkg/state/wire" ) @@ 
-90,6 +90,13 @@ type WriteCloser interface { io.Closer } +func writeMetadataLen(w io.Writer, val uint64) error { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], val) + _, err := w.Write(buf[:]) + return err +} + // NewWriter returns a state data writer for a statefile. // // Note that the returned WriteCloser must be closed. @@ -127,7 +134,7 @@ func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser } // Metadata length. - if err := binary.WriteUint64(mw, binary.BigEndian, uint64(len(b))); err != nil { + if err := writeMetadataLen(mw, uint64(len(b))); err != nil { return nil, err } // Metadata bytes; io.MultiWriter will return a short write error if @@ -158,6 +165,14 @@ func MetadataUnsafe(r io.Reader) (map[string]string, error) { return metadata(r, nil) } +func readMetadataLen(r io.Reader) (uint64, error) { + var buf [8]byte + if _, err := io.ReadFull(r, buf[:]); err != nil { + return 0, err + } + return binary.BigEndian.Uint64(buf[:]), nil +} + // metadata validates the magic header and reads out the metadata from a state // data stream. 
func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { @@ -183,7 +198,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { } }() - metadataLen, err := binary.ReadUint64(r, binary.BigEndian) + metadataLen, err := readMetadataLen(r) if err != nil { return nil, err } diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index b2c5229e7..8b3a11c64 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -43,6 +43,7 @@ go_template( ], deps = [ ":sync", + "//pkg/gohacks", ], ) diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go index 82b676abf..9578c9c52 100644 --- a/pkg/sync/generic_seqatomic_unsafe.go +++ b/pkg/sync/generic_seqatomic_unsafe.go @@ -10,6 +10,7 @@ package seqatomic import ( "unsafe" + "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/sync" ) @@ -39,7 +40,7 @@ func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) // runtime.RaceDisable() doesn't actually stop the race detector, so it // can't help us here. Instead, call runtime.memmove directly, which is // not instrumented by the race detector. - sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + gohacks.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { // This is ~40% faster for short reads than going through memmove. val = *ptr diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go index 158985709..39c766331 100644 --- a/pkg/sync/runtime_unsafe.go +++ b/pkg/sync/runtime_unsafe.go @@ -17,20 +17,6 @@ import ( "unsafe" ) -// Note that go:linkname silently doesn't work if the local name is exported, -// necessitating an indirection for exported functions. - -// Memmove is runtime.memmove, exported for SeqAtomicLoad/SeqAtomicTryLoad<T>. 
-// -//go:nosplit -func Memmove(to, from unsafe.Pointer, n uintptr) { - memmove(to, from, n) -} - -//go:linkname memmove runtime.memmove -//go:noescape -func memmove(to, from unsafe.Pointer, n uintptr) - // Gopark is runtime.gopark. Gopark calls unlockf(pointer to runtime.g, lock); // if unlockf returns true, Gopark blocks until Goready(pointer to runtime.g) // is called. unlockf and its callees must be nosplit and norace, since stack diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD index 5c38c783e..5f9164117 100644 --- a/pkg/sync/seqatomictest/BUILD +++ b/pkg/sync/seqatomictest/BUILD @@ -18,6 +18,7 @@ go_library( name = "seqatomic", srcs = ["seqatomic_int_unsafe.go"], deps = [ + "//pkg/gohacks", "//pkg/sync", ], ) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index f979d22f0..e96ba50ae 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,4 +1,5 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:deps.bzl", "deps_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -21,8 +22,9 @@ go_library( "errors.go", "sock_err_list.go", "socketops.go", + "stdclock.go", + "stdclock_state.go", "tcpip.go", - "time_unsafe.go", "timer.go", ], visibility = ["//visibility:public"], @@ -33,6 +35,36 @@ go_library( ], ) +deps_test( + name = "netstack_deps_test", + allowed = [ + "@com_github_google_btree//:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + "@org_golang_x_time//rate:go_default_library", + ], + allowed_prefixes = [ + "//", + "@org_golang_x_sys//internal/unsafeheader", + ], + targets = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/fdbased", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/packetsocket", + "//pkg/tcpip/link/qdisc/fifo", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", + 
"//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + ], +) + go_test( name = "tcpip_test", size = "small", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index fef065b05..12c39dfa3 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -53,9 +53,8 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { t.Error("Not a valid IPv4 packet") } - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) + if !ipv4.IsChecksumValid() { + t.Errorf("Bad checksum, got = %d", ipv4.Checksum()) } for _, f := range checkers { @@ -400,18 +399,11 @@ func TCP(checkers ...TransportChecker) NetworkChecker { t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } - // Verify the checksum. tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) + payload := tcp.Payload() + payloadChecksum := header.Checksum(payload, 0) + if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) { + t.Errorf("Bad checksum, got = %d", tcp.Checksum()) } // Run the transport checkers. 
diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go index 52c22230e..33ff22a7b 100644 --- a/pkg/tcpip/hash/jenkins/jenkins.go +++ b/pkg/tcpip/hash/jenkins/jenkins.go @@ -42,26 +42,26 @@ func (s *Sum32) Reset() { *s = 0 } // Sum32 returns the hash value func (s *Sum32) Sum32() uint32 { - hash := *s + sCopy := *s - hash += (hash << 3) - hash ^= hash >> 11 - hash += hash << 15 + sCopy += sCopy << 3 + sCopy ^= sCopy >> 11 + sCopy += sCopy << 15 - return uint32(hash) + return uint32(sCopy) } // Write adds more data to the running hash. // // It never returns an error. func (s *Sum32) Write(data []byte) (int, error) { - hash := *s + sCopy := *s for _, b := range data { - hash += Sum32(b) - hash += hash << 10 - hash ^= hash >> 6 + sCopy += Sum32(b) + sCopy += sCopy << 10 + sCopy ^= sCopy >> 6 } - *s = hash + *s = sCopy return len(data), nil } diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 0bdc12d53..01240f5d0 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -52,6 +52,7 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -69,6 +70,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 3bc8b2b21..bf9ccbf1a 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -18,6 +18,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) func TestIsValidUnicastEthernetAddress(t *testing.T) { @@ -142,7 +143,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { } func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) { - addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a") + addr := 
testutil.MustParse6("ff02:304:506:708:90a:b0c:d0e:f1a") if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want { t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want) } diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go index b6126d29a..575604928 100644 --- a/pkg/tcpip/header/igmp_test.go +++ b/pkg/tcpip/header/igmp_test.go @@ -18,8 +18,8 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) // TestIGMPHeader tests the functions within header.igmp @@ -46,7 +46,7 @@ func TestIGMPHeader(t *testing.T) { t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) } - if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { + if got, want := igmpHeader.GroupAddress(), testutil.MustParse4("1.2.3.4"); got != want { t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) } @@ -71,7 +71,7 @@ func TestIGMPHeader(t *testing.T) { t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) } - groupAddress := tcpip.Address("\x04\x03\x02\x01") + groupAddress := testutil.MustParse4("4.3.2.1") igmpHeader.SetGroupAddress(groupAddress) if got := igmpHeader.GroupAddress(); got != groupAddress { t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index f588311e0..2be21ec75 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -178,6 +178,26 @@ const ( IPv4FlagDontFragment ) +// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined +// by RFC 3927 section 1. 
+var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + +// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as +// defined by RFC 5771 section 4. +var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00")) + if err != nil { + panic(err) + } + return subnet +}() + // IPv4EmptySubnet is the empty IPv4 subnet. var IPv4EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any)) @@ -423,6 +443,44 @@ func (b IPv4) IsValid(pktSize int) bool { return true } +// IsV4LinkLocalUnicastAddress determines if the provided address is an IPv4 +// link-local unicast address. +func IsV4LinkLocalUnicastAddress(addr tcpip.Address) bool { + return ipv4LinkLocalUnicastSubnet.Contains(addr) +} + +// IsV4LinkLocalMulticastAddress determines if the provided address is an IPv4 +// link-local multicast address. +func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool { + return ipv4LinkLocalMulticastSubnet.Contains(addr) +} + +// IsChecksumValid returns true iff the IPv4 header's checksum is valid. +func (b IPv4) IsChecksumValid() bool { + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. 
+ // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + return b.CalculateChecksum() == 0xffff +} + // IsV4MulticastAddress determines if the provided address is an IPv4 multicast // address (range 224.0.0.0 to 239.255.255.255). The four most significant bits // will be 1110 = 0xe0. 
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go index 6475cd694..c02fe898b 100644 --- a/pkg/tcpip/header/ipv4_test.go +++ b/pkg/tcpip/header/ipv4_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -177,3 +178,77 @@ func TestIPv4EncodeOptions(t *testing.T) { }) } } + +func TestIsV4LinkLocalUnicastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid (lowest)", + addr: "\xa9\xfe\x00\x00", + expected: true, + }, + { + name: "Valid (highest)", + addr: "\xa9\xfe\xff\xff", + expected: true, + }, + { + name: "Invalid (before subnet)", + addr: "\xa9\xfd\xff\xff", + expected: false, + }, + { + name: "Invalid (after subnet)", + addr: "\xa9\xff\x00\x00", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV4LinkLocalMulticastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid (lowest)", + addr: "\xe0\x00\x00\x00", + expected: true, + }, + { + name: "Valid (highest)", + addr: "\xe0\x00\x00\xff", + expected: true, + }, + { + name: "Invalid (before subnet)", + addr: "\xdf\xff\xff\xff", + expected: false, + }, + { + name: "Invalid (after subnet)", + addr: "\xe0\x00\x01\x00", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6.go 
b/pkg/tcpip/header/ipv6.go index f2403978c..c3a0407ac 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -98,12 +98,27 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - // IPv6AllRoutersMulticastAddress is a link-local multicast group that - // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local + // multicast group that all IPv6 routers MUST join, as per RFC 4291, section + // 2.8. Packets destined to this address will reach the router on an + // interface. + // + // The address is ff01::2. + IPv6AllRoutersInterfaceLocalMulticastAddress tcpip.Address = "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets // destined to this address will reach all routers on a link. // // The address is ff02::2. - IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + IPv6AllRoutersLinkLocalMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers in a site. + // + // The address is ff05::2. + IPv6AllRoutersSiteLocalMulticastAddress tcpip.Address = "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, // section 5: @@ -142,11 +157,6 @@ const ( // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, // within the byte holding the field, as per RFC 4291 section 2.7. 
ipv6MulticastAddressScopeMask = 0xF - - // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within - // a multicast IPv6 address that indicates the address has link-local scope, - // as per RFC 4291 section 2.7. - ipv6LinkLocalMulticastScope = 2 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -381,25 +391,25 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { return tcpip.Address(lladdrb[:]) } -// IsV6LinkLocalAddress determines if the provided address is an IPv6 -// link-local address (fe80::/10). -func IsV6LinkLocalAddress(addr tcpip.Address) bool { +// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6 +// link-local unicast address, as defined by RFC 4291 section 2.5.6. +func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool { if len(addr) != IPv6AddressSize { return false } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } -// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback -// address. +// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback +// address, as defined by RFC 4291 section 2.5.3. func IsV6LoopbackAddress(addr tcpip.Address) bool { return addr == IPv6Loopback } -// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 -// link-local multicast address. +// IsV6LinkLocalMulticastAddress returns true iff the provided address is an +// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7. 
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { - return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope + return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope } // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier @@ -462,7 +472,7 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) { case IsV6LinkLocalMulticastAddress(addr): return LinkLocalScope, nil - case IsV6LinkLocalAddress(addr): + case IsV6LinkLocalUnicastAddress(addr): return LinkLocalScope, nil default: @@ -520,3 +530,46 @@ func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) PrefixLen: IIDOffsetInIPv6Address * 8, } } + +// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by +// RFC 7346 section 2. +type IPv6MulticastScope uint8 + +// The various values for IPv6 multicast scopes, as per RFC 7346 section 2: +// +// +------+--------------------------+-------------------------+ +// | scop | NAME | REFERENCE | +// +------+--------------------------+-------------------------+ +// | 0 | Reserved | [RFC4291], RFC 7346 | +// | 1 | Interface-Local scope | [RFC4291], RFC 7346 | +// | 2 | Link-Local scope | [RFC4291], RFC 7346 | +// | 3 | Realm-Local scope | [RFC4291], RFC 7346 | +// | 4 | Admin-Local scope | [RFC4291], RFC 7346 | +// | 5 | Site-Local scope | [RFC4291], RFC 7346 | +// | 6 | Unassigned | | +// | 7 | Unassigned | | +// | 8 | Organization-Local scope | [RFC4291], RFC 7346 | +// | 9 | Unassigned | | +// | A | Unassigned | | +// | B | Unassigned | | +// | C | Unassigned | | +// | D | Unassigned | | +// | E | Global scope | [RFC4291], RFC 7346 | +// | F | Reserved | [RFC4291], RFC 7346 | +// +------+--------------------------+-------------------------+ +const ( + IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0) + IPv6InterfaceLocalMulticastScope = 
IPv6MulticastScope(0x1) + IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2) + IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3) + IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4) + IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5) + IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8) + IPv6GlobalMulticastScope = IPv6MulticastScope(0xE) + IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF) +) + +// V6MulticastScope returns the scope of a multicast address. +func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope { + return IPv6MulticastScope(addr[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index f10f446a6..89be84068 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -24,15 +24,17 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") +const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + +var ( + linkLocalAddr = testutil.MustParse6("fe80::1") + linkLocalMulticastAddr = testutil.MustParse6("ff02::1") + uniqueLocalAddr1 = testutil.MustParse6("fc00::1") + uniqueLocalAddr2 = testutil.MustParse6("fd00::2") + globalAddr = testutil.MustParse6("a000::1") ) func TestEthernetAdddressToModifiedEUI64(t *testing.T) { @@ -50,7 +52,7 @@ func 
TestEthernetAdddressToModifiedEUI64(t *testing.T) { } func TestLinkLocalAddr(t *testing.T) { - if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + if got, want := header.LinkLocalAddr(linkAddr), testutil.MustParse6("fe80::2:3ff:fe04:506"); got != want { t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) } } @@ -252,7 +254,7 @@ func TestIsV6LinkLocalMulticastAddress(t *testing.T) { } } -func TestIsV6LinkLocalAddress(t *testing.T) { +func TestIsV6LinkLocalUnicastAddress(t *testing.T) { tests := []struct { name string addr tcpip.Address @@ -287,8 +289,8 @@ func TestIsV6LinkLocalAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) } }) } @@ -373,3 +375,83 @@ func TestSolicitedNodeAddr(t *testing.T) { }) } } + +func TestV6MulticastScope(t *testing.T) { + tests := []struct { + addr tcpip.Address + want header.IPv6MulticastScope + }{ + { + addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6Reserved0MulticastScope, + }, + { + addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6InterfaceLocalMulticastScope, + }, + { + addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6LinkLocalMulticastScope, + }, + { + addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6RealmLocalMulticastScope, + }, + { + addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6AdminLocalMulticastScope, + 
}, + { + addr: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6SiteLocalMulticastScope, + }, + { + addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(6), + }, + { + addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(7), + }, + { + addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6OrganizationLocalMulticastScope, + }, + { + addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(9), + }, + { + addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(10), + }, + { + addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(11), + }, + { + addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(12), + }, + { + addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(13), + }, + { + addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6GlobalMulticastScope, + }, + { + addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6ReservedFMulticastScope, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { + if got := header.V6MulticastScope(test.addr); got != test.want { + t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want) + } + }) + } +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index d0a1a2492..1b5093e58 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) // TestNDPNeighborSolicit tests 
the functions of NDPNeighborSolicit. @@ -40,13 +41,13 @@ func TestNDPNeighborSolicit(t *testing.T) { // Test getting the Target Address. ns := NDPNeighborSolicit(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10") if got := ns.TargetAddress(); got != addr { t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) } // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11") ns.SetTargetAddress(addr2) if got := ns.TargetAddress(); got != addr2 { t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) @@ -69,7 +70,7 @@ func TestNDPNeighborAdvert(t *testing.T) { // Test getting the Target Address. na := NDPNeighborAdvert(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10") if got := na.TargetAddress(); got != addr { t.Errorf("got TargetAddress = %s, want %s", got, addr) } @@ -90,7 +91,7 @@ func TestNDPNeighborAdvert(t *testing.T) { } // Test updating the Target Address. 
- addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11") na.SetTargetAddress(addr2) if got := na.TargetAddress(); got != addr2 { t.Errorf("got TargetAddress = %s, want %s", got, addr2) @@ -277,7 +278,7 @@ func TestOpts(t *testing.T) { } const validLifetimeSeconds = 16909060 - const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18") + address := testutil.MustParse6("90a:b0c:d0e:f10:1112:1314:1516:1718") expectedRDNSSBytes := [...]byte{ // Type, Length diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index adc835d30..0df517000 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -216,104 +216,104 @@ const ( TCPDefaultMSS = 536 ) -// SourcePort returns the "source port" field of the tcp header. +// SourcePort returns the "source port" field of the TCP header. func (b TCP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[TCPSrcPortOffset:]) } -// DestinationPort returns the "destination port" field of the tcp header. +// DestinationPort returns the "destination port" field of the TCP header. func (b TCP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[TCPDstPortOffset:]) } -// SequenceNumber returns the "sequence number" field of the tcp header. +// SequenceNumber returns the "sequence number" field of the TCP header. func (b TCP) SequenceNumber() uint32 { return binary.BigEndian.Uint32(b[TCPSeqNumOffset:]) } -// AckNumber returns the "ack number" field of the tcp header. +// AckNumber returns the "ack number" field of the TCP header. func (b TCP) AckNumber() uint32 { return binary.BigEndian.Uint32(b[TCPAckNumOffset:]) } -// DataOffset returns the "data offset" field of the tcp header. The return +// DataOffset returns the "data offset" field of the TCP header. The return // value is the length of the TCP header in bytes. 
func (b TCP) DataOffset() uint8 { return (b[TCPDataOffset] >> 4) * 4 } -// Payload returns the data in the tcp packet. +// Payload returns the data in the TCP packet. func (b TCP) Payload() []byte { return b[b.DataOffset():] } -// Flags returns the flags field of the tcp header. +// Flags returns the flags field of the TCP header. func (b TCP) Flags() TCPFlags { return TCPFlags(b[TCPFlagsOffset]) } -// WindowSize returns the "window size" field of the tcp header. +// WindowSize returns the "window size" field of the TCP header. func (b TCP) WindowSize() uint16 { return binary.BigEndian.Uint16(b[TCPWinSizeOffset:]) } -// Checksum returns the "checksum" field of the tcp header. +// Checksum returns the "checksum" field of the TCP header. func (b TCP) Checksum() uint16 { return binary.BigEndian.Uint16(b[TCPChecksumOffset:]) } -// UrgentPointer returns the "urgent pointer" field of the tcp header. +// UrgentPointer returns the "urgent pointer" field of the TCP header. func (b TCP) UrgentPointer() uint16 { return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:]) } -// SetSourcePort sets the "source port" field of the tcp header. +// SetSourcePort sets the "source port" field of the TCP header. func (b TCP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port) } -// SetDestinationPort sets the "destination port" field of the tcp header. +// SetDestinationPort sets the "destination port" field of the TCP header. func (b TCP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port) } -// SetChecksum sets the checksum field of the tcp header. +// SetChecksum sets the checksum field of the TCP header. func (b TCP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum) } -// SetDataOffset sets the data offset field of the tcp header. headerLen should +// SetDataOffset sets the data offset field of the TCP header. headerLen should // be the length of the TCP header in bytes. 
func (b TCP) SetDataOffset(headerLen uint8) { b[TCPDataOffset] = (headerLen / 4) << 4 } -// SetSequenceNumber sets the sequence number field of the tcp header. +// SetSequenceNumber sets the sequence number field of the TCP header. func (b TCP) SetSequenceNumber(seqNum uint32) { binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum) } -// SetAckNumber sets the ack number field of the tcp header. +// SetAckNumber sets the ack number field of the TCP header. func (b TCP) SetAckNumber(ackNum uint32) { binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum) } -// SetFlags sets the flags field of the tcp header. +// SetFlags sets the flags field of the TCP header. func (b TCP) SetFlags(flags uint8) { b[TCPFlagsOffset] = flags } -// SetWindowSize sets the window size field of the tcp header. +// SetWindowSize sets the window size field of the TCP header. func (b TCP) SetWindowSize(rcvwnd uint16) { binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } -// SetUrgentPoiner sets the window size field of the tcp header. +// SetUrgentPoiner sets the urgent pointer field of the TCP header. func (b TCP) SetUrgentPoiner(urgentPointer uint16) { binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer) } -// CalculateChecksum calculates the checksum of the tcp segment. +// CalculateChecksum calculates the checksum of the TCP segment. // partialChecksum is the checksum of the network-layer pseudo-header // and the checksum of the segment data. func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { @@ -321,6 +321,13 @@ func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { return Checksum(b[:b.DataOffset()], partialChecksum) } +// IsChecksumValid returns true iff the TCP header's checksum is valid.
+func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool { + xsum := PseudoHeaderChecksum(TCPProtocolNumber, src, dst, uint16(b.DataOffset())+payloadLength) + xsum = ChecksumCombine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + // Options returns a slice that holds the unparsed TCP options in the segment. func (b TCP) Options() []byte { return b[TCPMinimumSize:b.DataOffset()] @@ -340,7 +347,7 @@ func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } -// Encode encodes all the fields of the tcp header. +// Encode encodes all the fields of the TCP header. func (b TCP) Encode(t *TCPFields) { b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize) binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort) @@ -350,7 +357,7 @@ func (b TCP) Encode(t *TCPFields) { binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer) } -// EncodePartial updates a subset of the fields of the tcp header. It is useful +// EncodePartial updates a subset of the fields of the TCP header. It is useful // in cases when similar segments are produced. func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { // Add the total length and "flags" field contributions to the checksum. @@ -374,7 +381,7 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 } // ParseSynOptions parses the options received in a SYN segment and returns the -// relevant ones. opts should point to the option part of the TCP Header. +// relevant ones. opts should point to the option part of the TCP header. 
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { limit := len(opts) diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 98bdd29db..ae9d167ff 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -64,17 +64,17 @@ const ( UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) -// SourcePort returns the "source port" field of the udp header. +// SourcePort returns the "source port" field of the UDP header. func (b UDP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } -// DestinationPort returns the "destination port" field of the udp header. +// DestinationPort returns the "destination port" field of the UDP header. func (b UDP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } -// Length returns the "length" field of the udp header. +// Length returns the "length" field of the UDP header. func (b UDP) Length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } @@ -84,39 +84,46 @@ func (b UDP) Payload() []byte { return b[UDPMinimumSize:] } -// Checksum returns the "checksum" field of the udp header. +// Checksum returns the "checksum" field of the UDP header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) } -// SetSourcePort sets the "source port" field of the udp header. +// SetSourcePort sets the "source port" field of the UDP header. func (b UDP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[udpSrcPort:], port) } -// SetDestinationPort sets the "destination port" field of the udp header. +// SetDestinationPort sets the "destination port" field of the UDP header. func (b UDP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[udpDstPort:], port) } -// SetChecksum sets the "checksum" field of the udp header. +// SetChecksum sets the "checksum" field of the UDP header. 
func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } -// SetLength sets the "length" field of the udp header. +// SetLength sets the "length" field of the UDP header. func (b UDP) SetLength(length uint16) { binary.BigEndian.PutUint16(b[udpLength:], length) } -// CalculateChecksum calculates the checksum of the udp packet, given the +// CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. return Checksum(b[:UDPMinimumSize], partialChecksum) } -// Encode encodes all the fields of the udp header. +// IsChecksumValid returns true iff the UDP header's checksum is valid. +func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool { + xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length()) + xsum = ChecksumCombine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + +// Encode encodes all the fields of the UDP header. func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index cd76272de..ef9126deb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -30,7 +30,6 @@ import ( type PacketInfo struct { Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber - GSO *stack.GSO Route stack.RouteInfo } @@ -124,6 +123,9 @@ func (q *queue) RemoveNotify(handle *NotificationHandle) { q.notify = notify } +var _ stack.LinkEndpoint = (*Endpoint)(nil) +var _ stack.GSOEndpoint = (*Endpoint)(nil) + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. 
type Endpoint struct { @@ -131,6 +133,7 @@ type Endpoint struct { mtu uint32 linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities + SupportedGSOKind stack.SupportedGSO // Outbound packet queue. q *queue @@ -212,11 +215,16 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.LinkEPCapabilities } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (*Endpoint) GSOMaxSize() uint32 { return 1 << 15 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { + return e.SupportedGSOKind +} + // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*Endpoint) MaxHeaderLength() uint16 { @@ -229,11 +237,10 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { p := PacketInfo{ Pkt: pkt, Proto: protocol, - GSO: gso, Route: r, } @@ -243,13 +250,12 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip } // WritePackets stores outbound packets into the channel. 
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ Pkt: pkt, Proto: protocol, - GSO: gso, Route: r, } diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index d873766a6..b427c6170 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -61,20 +61,20 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) - return e.Endpoint.WritePacket(r, gso, proto, pkt) + return e.Endpoint.WritePacket(r, proto, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) } - return e.Endpoint.WritePackets(r, gso, pkts, proto) + return e.Endpoint.WritePackets(r, pkts, proto) } // MaxHeaderLength implements stack.LinkEndpoint. 
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index f042df82e..d971194e6 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,7 +14,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/binary", "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 6be945116..bddb1d0a2 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -45,7 +45,6 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -98,6 +97,9 @@ func (p PacketDispatchMode) String() string { } } +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.GSOEndpoint = (*endpoint)(nil) + type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. The endpoint will dispatch from all inbound channels as well as @@ -134,6 +136,9 @@ type endpoint struct { // wg keeps track of running goroutines. wg sync.WaitGroup + + // gsoKind is the supported kind of GSO. + gsoKind stack.SupportedGSO } // Options specify the details about the fd-based endpoint to be created. @@ -255,9 +260,9 @@ func New(opts *Options) (stack.LinkEndpoint, error) { if isSocket { if opts.GSOMaxSize != 0 { if opts.SoftwareGSOEnabled { - e.caps |= stack.CapabilitySoftwareGSO + e.gsoKind = stack.SWGSOSupported } else { - e.caps |= stack.CapabilityHardwareGSO + e.gsoKind = stack.HWGSOSupported } e.gsoMaxSize = opts.GSOMaxSize } @@ -403,6 +408,35 @@ type virtioNetHdr struct { csumOffset uint16 } +// marshal serializes h to a newly-allocated byte slice, in little-endian byte +// order. +// +// Note: Virtio v1.0 onwards specifies little-endian as the byte ordering used +// for general serialization. 
This makes it difficult to use go-marshal for +// virtio types, as go-marshal implicitly uses the native byte ordering. +func (h *virtioNetHdr) marshal() []byte { + buf := [virtioNetHdrSize]byte{ + 0: byte(h.flags), + 1: byte(h.gsoType), + + // Manually lay out the fields in little-endian byte order. Little endian => + // least significant byte goes to the lower address. + + 2: byte(h.hdrLen), + 3: byte(h.hdrLen >> 8), + + 4: byte(h.gsoSize), + 5: byte(h.gsoSize >> 8), + + 6: byte(h.csumStart), + 7: byte(h.csumStart >> 8), + + 8: byte(h.csumOffset), + 9: byte(h.csumOffset >> 8), + } + return buf[:] +} + // These constants are declared in linux/virtio_net.h. const ( _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1 @@ -433,7 +467,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if e.hdrSize > 0 { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } @@ -441,29 +475,29 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip var builder iovec.Builder fd := e.fds[pkt.Hash%uint32(len(e.fds))] - if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} - if gso != nil { + if pkt.GSOOptions.Type != stack.GSONone { vnetHdr.hdrLen = uint16(pkt.HeaderSize()) - if gso.NeedsCsum { + if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM - vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen - vnetHdr.csumOffset = gso.CsumOffset + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset =
pkt.GSOOptions.CsumOffset } - if gso.Type != stack.GSONone && uint16(pkt.Data().Size()) > gso.MSS { - switch gso.Type { + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 case stack.GSOTCPv6: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 default: - panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) } - vnetHdr.gsoSize = gso.MSS + vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + vnetHdrBuf := vnetHdr.marshal() builder.Add(vnetHdrBuf) } @@ -482,9 +516,9 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp } var vnetHdrBuf []byte - if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} - if pkt.GSOOptions != nil { + if pkt.GSOOptions.Type != stack.GSONone { vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM @@ -503,7 +537,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + vnetHdrBuf = vnetHdr.marshal() } var builder iovec.Builder @@ -540,7 +574,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { // Preallocate to avoid repeated reallocation as we append to batch. 
// batchSz is 47 because when SWGSO is in use then a single 65KB TCP // segment can get split into 46 segments of 1420 bytes and a single 216 @@ -602,11 +636,16 @@ func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { } } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// SupportedGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + return e.gsoKind +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (e *endpoint) ARPHardwareType() header.ARPHardwareType { if e.hdrSize > 0 { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 1e40f3fef..8aad338b6 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -207,18 +207,17 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u // Write. want := append(append(buffer.View(nil), b...), payload...) 
- var gso *stack.GSO + const l3HdrLen = header.IPv6MinimumSize if gsoMaxSize != 0 { - gso = &stack.GSO{ + pkt.GSOOptions = stack.GSO{ Type: stack.GSOTCPv6, NeedsCsum: true, CsumOffset: csumOffset, MSS: gsoMSS, - MaxSize: gsoMaxSize, - L3HdrLen: header.IPv4MaximumHeaderSize, + L3HdrLen: l3HdrLen, } } - if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { + if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -235,7 +234,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) } - csumStart := header.EthernetMinimumSize + gso.L3HdrLen + const csumStart = header.EthernetMinimumSize + l3HdrLen if vnetHdr.csumStart != csumStart { t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) } @@ -243,7 +242,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) } gsoType := uint8(0) - if int(gso.MSS) < plen { + if plen > gsoMSS { gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 } if vnetHdr.gsoType != gsoType { @@ -333,7 +332,7 @@ func TestPreserveSrcAddress(t *testing.T) { ReserveHeaderBytes: header.EthernetMinimumSize, Data: buffer.VectorisedView{}, }) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index a7adf822b..4b7ef3aac 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -128,7 +128,7 @@ type readVDispatcher struct { func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &readVDispatcher{fd: fd, e: e} - skipsVnetHdr := 
d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) return d, nil } @@ -212,7 +212,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { bufs: make([]*iovecBuffer, MaxMsgsPerRecv), msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } - skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported for i := range d.bufs { d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 691467870..7012d8829 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // Construct data as the unparsed portion for the loopback packet. data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N } // WritePackets implements stack.LinkEndpoint.WritePackets. 
-func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 668f72eee..3e2a1aa94 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,20 +87,20 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, &tcpip.ErrNoRoute{} } - return endpoint.WritePackets(r, gso, pkts, protocol) + return endpoint.WritePackets(r, pkts, protocol) } // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. 
-func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, gso, protocol, pkt) + return endpoint.WritePacket(r, protocol, pkt) } return &tcpip.ErrNoRoute{} } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 5806f7fdf..040e3a35b 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -54,7 +54,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { var packetRoute stack.RouteInfo packetRoute.RemoteAddress = dstIP - endpoint.WritePacket(packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) + endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -76,7 +76,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { pkt.TransportHeader().Push(1)[0] = 0xFA var packetRoute stack.RouteInfo packetRoute.RemoteAddress = dstIP - endpoint.WritePacket(packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) + endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 97ad9fdd5..3e816b0c7 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -113,13 +113,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket implements stack.LinkEndpoint. 
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return e.child.WritePacket(r, gso, protocol, pkt) +func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + return e.child.WritePacket(r, protocol, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return e.child.WritePackets(r, gso, pkts, protocol) +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return e.child.WritePackets(r, pkts, protocol) } // Wait implements stack.LinkEndpoint. @@ -135,6 +135,14 @@ func (e *Endpoint) GSOMaxSize() uint32 { return 0 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { + if e, ok := e.child.(stack.GSOEndpoint); ok { + return e.SupportedGSO() + } + return stack.GSONotSupported +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { return e.child.ARPHardwareType() diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index 6cbe18a56..e01837e2d 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -35,16 +35,16 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { } // WritePacket implements stack.LinkEndpoint.WritePacket. 
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) - return e.Endpoint.WritePacket(r, gso, protocol, pkt) + return e.Endpoint.WritePacket(r, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } - return e.Endpoint.WritePackets(r, gso, pkts, proto) + return e.Endpoint.WritePackets(r, pkts, proto) } diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 21fb87757..5030b6ba1 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -66,7 +66,7 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var pkts stack.PacketBufferList pkts.PushBack(pkt) e.deliverPackets(r, proto, pkts) @@ -74,7 +74,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw } // WritePackets implements stack.LinkEndpoint. 
-func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := pkts.Len() e.deliverPackets(r, proto, pkts) return n, nil diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 128ef6e87..b1a28491d 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -25,6 +25,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.GSOEndpoint = (*endpoint)(nil) + // endpoint represents a LinkEndpoint which implements a FIFO queue for all // outgoing packets. endpoint can have 1 or more underlying queueDispatchers. // All outgoing packets are consistenly hashed to a single underlying queue @@ -91,7 +94,7 @@ func (q *queueDispatcher) dispatchLoop() { } // We pass a protocol of zero here because each packet carries its // NetworkProtocol. - q.lower.WritePackets(stack.RouteInfo{}, nil /* gso */, batch, 0 /* protocol */) + q.lower.WritePackets(stack.RouteInfo{}, batch, 0 /* protocol */) for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() { batch.Remove(pkt) } @@ -141,7 +144,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.lower.LinkAddress() } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (e *endpoint) GSOMaxSize() uint32 { if gso, ok := e.lower.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() @@ -149,13 +152,21 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + if gso, ok := e.lower.(stack.GSOEndpoint); ok { + return gso.SupportedGSO() + } + return stack.GSONotSupported +} + // WritePacket implements stack.LinkEndpoint.WritePacket. 
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - // WritePacket caller's do not set the following fields in PacketBuffer - // so we populate them here. - pkt.EgressRoute = r - pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = protocol +// +// The packet must have the following fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] if !d.q.enqueue(pkt) { return &tcpip.ErrNoBufferSpace{} @@ -166,12 +177,12 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements stack.LinkEndpoint.WritePackets. // -// Being a batch API, each packet in pkts should have the following -// fields populated: +// Each packet in the packet buffer list must have the following fields +// populated: // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { enqueued := 0 for pkt := pkts.Front(); pkt != nil; { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index d8d0b16b2..df9a0b90a 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. 
-func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) views := pkt.Views() @@ -220,7 +220,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index d4b3ddd5c..0f72d4e95 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -281,7 +281,7 @@ func TestSimpleSend(t *testing.T) { copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -351,7 +351,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { }) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -407,7 +407,7 @@ func TestFillTxQueue(t *testing.T) { Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -425,7 
+425,7 @@ func TestFillTxQueue(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } @@ -453,7 +453,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -476,7 +476,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -494,7 +494,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Fatalf("got WritePacket(...) 
= %v, want %s", err, &tcpip.ErrWouldBlock{}) } @@ -520,7 +520,7 @@ func TestFillTxMemory(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -539,7 +539,7 @@ func TestFillTxMemory(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } @@ -566,7 +566,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -581,7 +581,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buffer.NewView(bufferSize).ToVectorisedView(), }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Fatalf("got WritePacket(...) 
= %v, want %s", err, &tcpip.ErrWouldBlock{}) } @@ -593,7 +593,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 7aaee3d13..2d6a3a833 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -139,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dumpPacket(directionRecv, nil, protocol, pkt) + e.dumpPacket(directionRecv, protocol, pkt) e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -148,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } -func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(e.logPrefix, dir, protocol, pkt, gso) + logPacket(e.logPrefix, dir, protocol, pkt) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Size() @@ -187,22 +187,22 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. 
It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.dumpPacket(directionSend, gso, protocol, pkt) - return e.Endpoint.WritePacket(r, gso, protocol, pkt) +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + e.dumpPacket(directionSend, protocol, pkt) + return e.Endpoint.WritePacket(r, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.dumpPacket(directionSend, gso, protocol, pkt) + e.dumpPacket(directionSend, protocol, pkt) } - return e.Endpoint.WritePackets(r, gso, pkts, protocol) + return e.Endpoint.WritePackets(r, pkts, protocol) } -func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { // Figure out the network layer info. 
var transProto uint8 src := tcpip.Address("unknown") @@ -411,8 +411,8 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe return } - if gso != nil { - details += fmt.Sprintf(" gso: %+v", gso) + if pkt.GSOOptions.Type != stack.GSONone { + details += fmt.Sprintf(" gso: %#v", pkt.GSOOptions) } log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index ce5113746..a95602aa5 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -108,12 +108,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, gso, protocol, pkt) + err := e.lower.WritePacket(r, protocol, pkt) e.writeGate.Leave() return err } @@ -121,12 +121,12 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. 
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { if !e.writeGate.Enter() { return pkts.Len(), nil } - n, err := e.lower.WritePackets(r, gso, pkts, protocol) + n, err := e.lower.WritePackets(r, pkts, protocol) e.writeGate.Leave() return n, err } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index e368a9eaa..a71400ee9 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { +func (e *countedEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { e.writeCount += pkts.Len() return pkts.Len(), nil } @@ -98,21 +98,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) + wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. 
wep.WaitDispatch() - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) + wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) + wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index fa8814bac..7b1ff44f4 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -21,6 +21,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d59d678b2..6905b9ccb 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -33,6 +33,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7ae38d684..0efa3a926 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -136,7 +136,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (*endpoint) Close() {} -func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { +func (*endpoint) WritePacket(*stack.Route, 
stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -146,7 +146,7 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*endpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { return 0, &tcpip.ErrNotSupported{} } @@ -222,7 +222,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // // Send the packet to the (new) target hardware address on the same // hardware on which the request was received. - if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil { + if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), ProtocolNumber, respPkt); err != nil { stats.outgoingRepliesDropped.Increment() } else { stats.outgoingRepliesSent.Increment() @@ -355,7 +355,7 @@ func (e *endpoint) sendARPRequest(localAddr, targetAddr tcpip.Address, remoteLin } stats := e.stats.arp - if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, ProtocolNumber, pkt); err != nil { stats.outgoingRequestsDropped.Increment() return err } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 018d6a578..94209b026 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -30,20 +30,16 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) const ( nicID = 1 - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - - 
remoteAddr = tcpip.Address("\x0a\x00\x00\x02") + stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - unknownAddr = tcpip.Address("\x0a\x00\x00\x03") - defaultChannelSize = 1 defaultMTU = 65536 @@ -54,6 +50,12 @@ const ( eventChanSize = 32 ) +var ( + stackAddr = testutil.MustParse4("10.0.0.1") + remoteAddr = testutil.MustParse4("10.0.0.2") + unknownAddr = testutil.MustParse4("10.0.0.3") +) + type eventType uint8 const ( @@ -449,12 +451,12 @@ type testLinkEndpoint struct { writeErr tcpip.Error } -func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if t.writeErr != nil { return t.writeErr } - return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) + return t.LinkEndpoint.WritePacket(r, protocol, pkt) } func TestLinkAddressRequest(t *testing.T) { diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD index d21b4c7ef..fd944ce99 100644 --- a/pkg/tcpip/network/internal/ip/BUILD +++ b/pkg/tcpip/network/internal/ip/BUILD @@ -6,6 +6,7 @@ go_library( name = "ip", srcs = [ "duplicate_address_detection.go", + "errors.go", "generic_multicast_protocol.go", "stats.go", ], diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go new file mode 100644 index 000000000..50fabfd79 --- /dev/null +++ b/pkg/tcpip/network/internal/ip/errors.go @@ -0,0 +1,77 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// ForwardingError represents an error that occurred while trying to forward +// a packet. +type ForwardingError interface { + isForwardingError() + fmt.Stringer +} + +// ErrTTLExceeded indicates that the received packet's TTL has been exceeded. +type ErrTTLExceeded struct{} + +func (*ErrTTLExceeded) isForwardingError() {} + +func (*ErrTTLExceeded) String() string { return "ttl exceeded" } + +// ErrIPOptProblem indicates the received packet had a problem with an IP +// option. +type ErrIPOptProblem struct{} + +func (*ErrIPOptProblem) isForwardingError() {} + +func (*ErrIPOptProblem) String() string { return "ip option problem" } + +// ErrLinkLocalSourceAddress indicates the received packet had a link-local +// source address. +type ErrLinkLocalSourceAddress struct{} + +func (*ErrLinkLocalSourceAddress) isForwardingError() {} + +func (*ErrLinkLocalSourceAddress) String() string { return "link local destination address" } + +// ErrLinkLocalDestinationAddress indicates the received packet had a link-local +// destination address. +type ErrLinkLocalDestinationAddress struct{} + +func (*ErrLinkLocalDestinationAddress) isForwardingError() {} + +func (*ErrLinkLocalDestinationAddress) String() string { return "link local destination address" } + +// ErrNoRoute indicates the Netstack couldn't find a route for the +// received packet.
+type ErrNoRoute struct{} + +func (*ErrNoRoute) isForwardingError() {} + +func (*ErrNoRoute) String() string { return "no route" } + +// ErrOther indicates the packet could not be forwarded for a reason +// captured by the contained error. +type ErrOther struct { + Err tcpip.Error +} + +func (*ErrOther) isForwardingError() {} + +func (e *ErrOther) String() string { return fmt.Sprintf("other tcpip error: %s", e.Err) } diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index b9f129728..d22974b12 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ip holds IPv4/IPv6 common utilities. package ip import ( @@ -156,14 +155,6 @@ type GenericMulticastProtocolOptions struct { // // Unsolicited reports are transmitted when a group is newly joined. MaxUnsolicitedReportDelay time.Duration - - // AllNodesAddress is a multicast address that all nodes on a network should - // be a member of. - // - // This address will not have the generic multicast protocol performed on it; - // it will be left in the non member/listener state, and packets will never - // be sent for it. - AllNodesAddress tcpip.Address } // MulticastGroupProtocol is a multicast group protocol whose core state machine @@ -188,6 +179,10 @@ type MulticastGroupProtocol interface { // SendLeave sends a multicast leave for the specified group address. SendLeave(groupAddress tcpip.Address) tcpip.Error + + // ShouldPerformProtocol returns true iff the protocol should be performed for + // the specified group.
+ ShouldPerformProtocol(tcpip.Address) bool } // GenericMulticastProtocolState is the per interface generic multicast protocol @@ -455,20 +450,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t info.lastToSendReport = false - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { info.state = idleMember return } @@ -537,20 +519,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres return } - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. 
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { return } @@ -627,20 +596,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr return } - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { return } diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 381460c82..0b51563cd 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -43,6 +43,8 @@ type mockMulticastGroupProtocolProtectedFields struct { type mockMulticastGroupProtocol struct { t *testing.T + skipProtocolAddress tcpip.Address + mu mockMulticastGroupProtocolProtectedFields } @@ -165,6 +167,11 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip return nil } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. 
+func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + return groupAddress != m.skipProtocolAddress +} + func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { m.mu.Lock() defer m.mu.Unlock() @@ -193,10 +200,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr cmp.FilterPath( func(p cmp.Path) bool { switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress": return true + default: + return false } - return false }, cmp.Ignore(), ), @@ -225,14 +233,13 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(0)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, }) // Joining a group should send a report immediately and another after @@ -279,14 +286,13 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(1)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, }) mgp.joinGroup(test.addr) @@ -356,14 +362,13 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := 
faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(2)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) @@ -446,14 +451,13 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) @@ -574,14 +578,13 @@ func TestJoinCount(t *testing.T) { } func TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index b6f39ddb1..392f0b0c7 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -18,70 +18,114 @@ import "gvisor.dev/gvisor/pkg/tcpip" // LINT.IfChange(MultiCounterIPStats) +// MultiCounterIPForwardingStats holds IP forwarding statistics. Each counter +// may have several versions. +type MultiCounterIPForwardingStats struct { + // Unrouteable is the number of IP packets received which were dropped + // because the netstack could not construct a route to their + // destination. + Unrouteable tcpip.MultiCounterStat + + // ExhaustedTTL is the number of IP packets received which were dropped + // because their TTL was exhausted. 
+ ExhaustedTTL tcpip.MultiCounterStat + + // LinkLocalSource is the number of IP packets which were dropped + // because they contained a link-local source address. + LinkLocalSource tcpip.MultiCounterStat + + // LinkLocalDestination is the number of IP packets which were dropped + // because they contained a link-local destination address. + LinkLocalDestination tcpip.MultiCounterStat + + // Errors is the number of IP packets received which could not be + // successfully forwarded. + Errors tcpip.MultiCounterStat +} + // MultiCounterIPStats holds IP statistics, each counter may have several // versions. type MultiCounterIPStats struct { - // PacketsReceived is the total number of IP packets received from the link + // PacketsReceived is the number of IP packets received from the link // layer. PacketsReceived tcpip.MultiCounterStat - // DisabledPacketsReceived is the total number of IP packets received from the - // link layer when the IP layer is disabled. + // DisabledPacketsReceived is the number of IP packets received from + // the link layer when the IP layer is disabled. DisabledPacketsReceived tcpip.MultiCounterStat - // InvalidDestinationAddressesReceived is the total number of IP packets + // InvalidDestinationAddressesReceived is the number of IP packets // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived tcpip.MultiCounterStat - // InvalidSourceAddressesReceived is the total number of IP packets received - // with a source address that should never have been received on the wire. + // InvalidSourceAddressesReceived is the number of IP packets received + // with a source address that should never have been received on the + // wire. InvalidSourceAddressesReceived tcpip.MultiCounterStat - // PacketsDelivered is the total number of incoming IP packets that are + // PacketsDelivered is the number of incoming IP packets that are // successfully delivered to the transport layer. 
PacketsDelivered tcpip.MultiCounterStat - // PacketsSent is the total number of IP packets sent via WritePacket. + // PacketsSent is the number of IP packets sent via WritePacket. PacketsSent tcpip.MultiCounterStat - // OutgoingPacketErrors is the total number of IP packets which failed to + // OutgoingPacketErrors is the number of IP packets which failed to // write to a link-layer endpoint. OutgoingPacketErrors tcpip.MultiCounterStat - // MalformedPacketsReceived is the total number of IP Packets that were + // MalformedPacketsReceived is the number of IP Packets that were // dropped due to the IP packet header failing validation checks. MalformedPacketsReceived tcpip.MultiCounterStat - // MalformedFragmentsReceived is the total number of IP Fragments that were + // MalformedFragmentsReceived is the number of IP Fragments that were // dropped due to the fragment failing validation checks. MalformedFragmentsReceived tcpip.MultiCounterStat - // IPTablesPreroutingDropped is the total number of IP packets dropped in the + // IPTablesPreroutingDropped is the number of IP packets dropped in the // Prerouting chain. IPTablesPreroutingDropped tcpip.MultiCounterStat - // IPTablesInputDropped is the total number of IP packets dropped in the Input - // chain. + // IPTablesInputDropped is the number of IP packets dropped in the + // Input chain. IPTablesInputDropped tcpip.MultiCounterStat - // IPTablesOutputDropped is the total number of IP packets dropped in the + // IPTablesOutputDropped is the number of IP packets dropped in the // Output chain. IPTablesOutputDropped tcpip.MultiCounterStat - // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out - // of IPStats. + // IPTablesPostroutingDropped is the number of IP packets dropped in + // the Postrouting chain. + IPTablesPostroutingDropped tcpip.MultiCounterStat + + // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option + // stats out of IPStats. 
// OptionTimestampReceived is the number of Timestamp options seen. OptionTimestampReceived tcpip.MultiCounterStat - // OptionRecordRouteReceived is the number of Record Route options seen. + // OptionRecordRouteReceived is the number of Record Route options + // seen. OptionRecordRouteReceived tcpip.MultiCounterStat - // OptionRouterAlertReceived is the number of Router Alert options seen. + // OptionRouterAlertReceived is the number of Router Alert options + // seen. OptionRouterAlertReceived tcpip.MultiCounterStat // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived tcpip.MultiCounterStat + + // Forwarding collects stats related to IP forwarding. + Forwarding MultiCounterIPForwardingStats +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { + m.Unrouteable.Init(a.Unrouteable, b.Unrouteable) + m.Errors.Init(a.Errors, b.Errors) + m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource) + m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination) + m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) } // Init sets internal counters to track a and b counters. 
@@ -98,10 +142,12 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) + m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped) m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived) m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived) m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived) m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) + m.Forwarding.Init(&a.Forwarding, &b.Forwarding) } // LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index f5fa77b65..e2cf24b67 100644 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -64,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } // WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if ep.allowPackets == 0 { return ep.err } @@ -74,11 +74,11 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip } // WritePackets implements LinkEndpoint.WritePackets. 
-func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { var n int for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { + if err := ep.WritePacket(r, protocol, pkt); err != nil { return n, err } n++ diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a4edc69c7..74aad126c 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "fmt" "strings" "testing" @@ -29,23 +30,25 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -const ( - localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") - remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") - ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") - ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") - ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") - localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") - ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") - ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - nicID = 1 +const nicID = 1 + +var ( + localIPv4Addr = testutil.MustParse4("10.0.0.1") + remoteIPv4Addr = 
testutil.MustParse4("10.0.0.2") + ipv4SubnetAddr = testutil.MustParse4("10.0.0.0") + ipv4SubnetMask = testutil.MustParse4("255.255.255.0") + ipv4Gateway = testutil.MustParse4("10.0.0.3") + localIPv6Addr = testutil.MustParse6("a00::1") + remoteIPv6Addr = testutil.MustParse6("a00::2") + ipv6SubnetAddr = testutil.MustParse6("a00::") + ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00") + ipv6Gateway = testutil.MustParse6("a00::3") ) var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ @@ -180,7 +183,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -202,7 +205,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. 
-func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } @@ -323,7 +326,7 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -588,7 +591,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -1015,7 +1018,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -1938,3 +1941,80 @@ func TestICMPInclusionSize(t *testing.T) { }) } } + +func TestJoinLeaveAllRoutersGroup(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + protoFactory stack.NetworkProtocolFactory + allRoutersAddr tcpip.Address + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + protoFactory: ipv4.NewProtocol, + allRoutersAddr: header.IPv4AllRoutersGroup, + }, + { + name: "IPv6 Interface Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress, + }, + { + name: "IPv6 Link Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + 
allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress, + }, + { + name: "IPv6 Site Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, nicDisabled := range [...]bool{true, false} { + t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + opts := stack.NICOptions{Disabled: nicDisabled} + if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) + } + + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if !got { + t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, false); err != nil { + t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 
5e7f10f4b..7ee0495d9 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -45,6 +45,7 @@ go_test( "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 1525f15db..c8ed1ce79 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet return } - // Skip the ip header, then deliver the error. - pkt.Data().TrimFront(hlen) + // Keep needed information before trimming header. p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) + dstAddr := hdr.DestinationAddress() + // Skip the ip header, then deliver the error. + pkt.Data().DeleteFront(hlen) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { @@ -336,14 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4DstUnreachable: received.dstUnreachable.Increment() - pkt.Data().TrimFront(header.ICMPv4MinimumSize) - switch h.Code() { + mtu := h.MTU() + code := h.Code() + pkt.Data().DeleteFront(header.ICMPv4MinimumSize) + switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) case header.ICMPv4PortUnreachable: e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt) case header.ICMPv4FragmentationNeeded: - networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize) if err != nil { networkMTU = 0 } @@ -442,6 +446,23 @@ func (r *icmpReasonParamProblem) isForwarding() bool { return 
r.forwarding } +// icmpReasonNetworkUnreachable is an error in which the network specified in +// the internet destination field of the datagram is unreachable. +type icmpReasonNetworkUnreachable struct{} + +func (*icmpReasonNetworkUnreachable) isICMPReason() {} +func (*icmpReasonNetworkUnreachable) isForwarding() bool { + // If we hit a Net Unreachable error, then we know we are operating as + // a router. As per RFC 792 page 5, Destination Unreachable Message, + // + // If, according to the information in the gateway's routing tables, + // the network specified in the internet destination field of a + // datagram is unreachable, e.g., the distance to the network is + // infinity, the gateway may send a destination unreachable message to + // the internet source host of the datagram. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -610,6 +631,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetworkUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4NetUnreachable) + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) @@ -629,7 +654,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( - nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, TTL: route.DefaultTTL(), diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index f3fc1c87e..3ce499298 100644 --- 
a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -126,6 +126,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error { return err } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + // As per RFC 2236 section 6 page 10, + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + return groupAddress != header.IPv4AllSystems +} + // init sets up an igmpState struct, and is required to be called before using // a new igmpState. // @@ -137,7 +148,6 @@ func (igmp *igmpState) init(ep *endpoint) { Clock: ep.protocol.stack.Clock(), Protocol: igmp, MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, - AllNodesAddress: header.IPv4AllSystems, }) igmp.igmpV1Present = igmpV1PresentDefault igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { @@ -331,7 +341,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip } sentStats := igmp.ep.stats.igmp.packetsSent - if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), ProtocolNumber, pkt); err != nil { sentStats.dropped.Increment() return false, err } diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index e5e1b89cc..4bd6f462e 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -26,18 +26,22 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( linkAddr = 
tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") - multicastAddr = tcpip.Address("\xe0\x00\x00\x03") nicID = 1 defaultTTL = 1 defaultPrefixLength = 24 ) +var ( + stackAddr = testutil.MustParse4("10.0.0.1") + remoteAddr = testutil.MustParse4("10.0.0.2") + multicastAddr = testutil.MustParse4("224.0.0.3") +) + // validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet // sent to the provided address with the passed fields set. Raises a t.Error if // any field does not match. @@ -292,7 +296,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPLeaveGroup, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, @@ -302,7 +306,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPMembershipQuery, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: true, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() }, @@ -312,7 +316,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPv1MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, @@ -322,7 +326,7 @@ func 
TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPv2MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, @@ -332,7 +336,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPv2MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{ - {Address: tcpip.Address("\x0a\x00\x0f\x01"), PrefixLen: 24}, + {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24}, {Address: stackAddr, PrefixLen: 24}, }, srcAddr: remoteAddr, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1a5661ca4..b11e56c6a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -150,6 +151,38 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if forwarding { + // There does not seem to be an RFC requirement for a node to join the all + // routers multicast address but + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml + // specifies the address as a group for all routers on a subnet so we join + // the group here. 
+ if err := e.joinGroupLocked(header.IPv4AllRoutersGroup); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } + + return + } + + switch err := e.leaveGroupLocked(header.IPv4AllRoutersGroup).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } +} + // Enable implements stack.NetworkEndpoint. func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() @@ -226,7 +259,7 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) { + switch err := e.leaveGroupLocked(header.IPv4AllSystems).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) @@ -318,7 +351,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the // original packet. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { // Round the MTU down to align to 8 bytes. 
fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -338,7 +371,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { if err := e.addIPHeader(r.LocalAddress(), r.RemoteAddress(), pkt, params, nil /* options */); err != nil { return err } @@ -346,7 +379,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. 
e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -369,10 +402,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } } - return e.writePacket(r, gso, pkt, false /* headerIncluded */) + return e.writePacket(r, pkt, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop()&stack.PacketLoop != 0 { // If the packet was generated by the stack (not a raw/packet endpoint // where a packet may be written with the header included), then we can @@ -383,6 +416,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + // Postrouting NAT can only change the source address, and does not alter the + // route or outgoing interface of the packet. + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + // iptables is telling us to drop the packet. 
+ e.stats.ip.IPTablesPostroutingDropped.Increment() + return nil + } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) @@ -391,20 +433,20 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return err } - if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { + if packetMustBeFragmented(pkt, networkMTU) { + sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. - return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + return e.nic.WritePacket(r, ProtocolNumber, fragPkt) }) stats.PacketsSent.IncrementBy(uint64(sent)) stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, ProtocolNumber, pkt); err != nil { stats.OutgoingPacketErrors.Increment() return err } @@ -413,7 +455,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet } // WritePackets implements stack.NetworkEndpoint.WritePackets. 
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("multiple packets in local loop") } @@ -434,11 +476,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return 0, err } - if packetMustBeFragmented(pkt, networkMTU, gso) { + if packetMustBeFragmented(pkt, networkMTU) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { + if _, _, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt @@ -454,9 +496,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. 
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) - stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - for pkt := range dropped { + outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) + for pkt := range outputDropped { pkts.Remove(pkt) } @@ -478,14 +520,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } + // We ignore the list of NAT-ed packets here because Postrouting NAT can only + // change the source address, and does not alter the route or outgoing + // interface of the packet. + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) + for pkt := range postroutingDropped { + pkts.Remove(pkt) + } + // The rest of the packets can be delivered to the NIC as a batch. pktsLen := pkts.Len() - written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + written, err := e.nic.WritePackets(r, pkts, ProtocolNumber) stats.PacketsSent.IncrementBy(uint64(written)) stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) // Dropped packets aren't errors, so include them in the return value. - return locallyDelivered + written + len(dropped), err + return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. @@ -545,12 +596,31 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return &tcpip.ErrMalformedHeader{} } - return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) + return e.writePacket(r, pkt, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. 
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv4(pkt.NetworkHeader().View()) + + dstAddr := h.DestinationAddress() + // As per RFC 3927 section 7, + // + // A router MUST NOT forward a packet with an IPv4 Link-Local source or + // destination address, irrespective of the router's default route + // configuration or routes obtained from dynamic routing protocols. + // + // A router which receives a packet with an IPv4 Link-Local source or + // destination address MUST NOT forward the packet. This prevents + // forwarding of packets back onto the network segment from which they + // originated, or to any other segment. + if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} + } + ttl := h.TTL() if ttl == 0 { // As per RFC 792 page 6, Time Exceeded Message, @@ -558,7 +628,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // If the gateway processing a datagram finds the time to live field // is zero it must discard the datagram. The gateway may also notify // the source host via the time exceeded message. - return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. 
+ _ = e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } if opts := h.Options(); len(opts) != 0 { @@ -569,10 +644,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { pointer: optProblem.Pointer, forwarding: true, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - e.stats.ip.MalformedPacketsReceived.Increment() } - return nil // option problems are not reported locally. + return &ip.ErrIPOptProblem{} } copied := copy(opts, newOpts) if copied != len(newOpts) { @@ -589,8 +662,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } } - dstAddr := h.DestinationAddress() - // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { ep.handleValidatedPacket(h, pkt) @@ -598,8 +669,16 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonNetworkUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() @@ -616,10 +695,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // spent, the field must be decremented by 1. 
newHdr.SetTTL(ttl - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + })); err != nil { + return &ip.ErrOther{Err: err} + } + return nil } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -668,7 +750,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -734,14 +816,31 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) stats.ip.InvalidDestinationAddressesReceived.Increment() return } - _ = e.forwardPacket(pkt) + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrIPOptProblem: + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + stats.ip.MalformedPacketsReceived.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + stats.ip.Forwarding.Errors.Increment() return } // iptables filtering. 
All packets that reach here are intended for // this machine and will not be forwarded. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -1114,28 +1213,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) return nil, false } - // There has been some confusion regarding verifying checksums. We need - // just look for negative 0 (0xffff) as the checksum, as it's not possible to - // get positive 0 (0) for the checksum. Some bad implementations could get it - // when doing entry replacement in the early days of the Internet, - // however the lore that one needs to check for both persists. - // - // RFC 1624 section 1 describes the source of this confusion as: - // [the partial recalculation method described in RFC 1071] computes a - // result for certain cases that differs from the one obtained from - // scratch (one's complement of one's complement sum of the original - // fields). - // - // However RFC 1624 section 5 clarifies that if using the verification method - // "recommended by RFC 1071, it does not matter if an intermediate system - // generated a -0 instead of +0". - // - // RFC1071 page 1 specifies the verification method as: - // (3) To check a checksum, the 1's complement sum is computed over the - // same set of octets, including the checksum field. If the result - // is all 1 bits (-0 in 1's complement arithmetic), the check - // succeeds. 
- if h.CalculateChecksum() != 0xffff { + if !h.IsChecksumValid() { return nil, false } @@ -1168,12 +1246,27 @@ func (p *protocol) Forwarding() bool { return uint8(atomic.LoadUint32(&p.forwarding)) == 1 } +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) + } + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) +} + // SetForwarding implements stack.ForwardingNetworkProtocol. func (p *protocol) SetForwarding(v bool) { - if v { - atomic.StoreUint32(&p.forwarding, 1) - } else { - atomic.StoreUint32(&p.forwarding, 0) + p.mu.Lock() + defer p.mu.Unlock() + + if !p.setForwarding(v) { + return + } + + for _, ep := range p.mu.eps { + ep.transitionForwarding(v) } } @@ -1200,9 +1293,9 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error return networkMTU - uint32(networkHeaderSize), nil } -func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool { payload := pkt.TransportHeader().View().Size() + pkt.Data().Size() - return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU + return pkt.GSOOptions.Type == stack.GSONone && uint32(payload) > networkMTU } // addressToUint32 translates an IPv4 address into its little endian uint32 diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index eba91c68c..7a7cad04a 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -39,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 
"gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -130,48 +131,69 @@ func TestForwarding(t *testing.T) { } remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + unreachableIPv4Addr := tcpip.Address(net.ParseIP("12.0.0.2").To4()) + multicastIPv4Addr := tcpip.Address(net.ParseIP("225.0.0.0").To4()) + linkLocalIPv4Addr := tcpip.Address(net.ParseIP("169.254.0.0").To4()) tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code + name string + TTL uint8 + sourceAddr tcpip.Address + destAddr tcpip.Address + expectErrorICMP bool + expectPacketForwarded bool + options header.IPv4Options + forwardedOptions header.IPv4Options + icmpType header.ICMPv4Type + icmpCode header.ICMPv4Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool }{ { name: "TTL of zero", TTL: 0, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, expectErrorICMP: true, icmpType: header.ICMPv4TimeExceeded, icmpCode: header.ICMPv4TTLExceeded, }, { - name: "TTL of one", - TTL: 1, - expectErrorICMP: false, + name: "TTL of one", + TTL: 1, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "TTL of two", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Max TTL", + TTL: math.MaxUint8, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, }, { - name: "four EOL options", - TTL: 2, - expectErrorICMP: false, - options: header.IPv4Options{0, 0, 0, 0}, - forwardedOptions: header.IPv4Options{0, 0, 0, 0}, + name: "four EOL options", + 
TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + options: header.IPv4Options{0, 0, 0, 0}, + forwardedOptions: header.IPv4Options{0, 0, 0, 0}, }, { - name: "TS type 1 full", - TTL: 2, + name: "TS type 1 full", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, options: header.IPv4Options{ 68, 12, 13, 0xF1, 192, 168, 1, 12, @@ -182,8 +204,10 @@ func TestForwarding(t *testing.T) { icmpCode: header.ICMPv4UnusedCode, }, { - name: "TS type 0", - TTL: 2, + name: "TS type 0", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, options: header.IPv4Options{ 68, 24, 21, 0x00, 1, 2, 3, 4, @@ -200,10 +224,13 @@ func TestForwarding(t *testing.T) { 13, 14, 15, 16, 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, + expectPacketForwarded: true, }, { - name: "end of options list", - TTL: 2, + name: "end of options list", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, options: header.IPv4Options{ 68, 12, 13, 0x11, 192, 168, 1, 12, @@ -219,6 +246,37 @@ func TestForwarding(t *testing.T) { 0, 0, 0, // 7 bytes unknown option removed. 
0, 0, 0, 0, }, + expectPacketForwarded: true, + }, + { + name: "Network unreachable", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: unreachableIPv4Addr, + expectErrorICMP: true, + icmpType: header.ICMPv4DstUnreachable, + icmpCode: header.ICMPv4NetUnreachable, + expectPacketUnrouteableError: true, + }, + { + name: "Multicast destination", + TTL: 2, + destAddr: multicastIPv4Addr, + expectPacketUnrouteableError: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: linkLocalIPv4Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv4Addr, + destAddr: remoteIPv4Addr2, + expectLinkLocalSourceError: true, }, } for _, test := range tests { @@ -286,8 +344,8 @@ func TestForwarding(t *testing.T) { TotalLength: totalLen, Protocol: uint8(header.ICMPv4ProtocolNumber), TTL: test.TTL, - SrcAddr: remoteIPv4Addr1, - DstAddr: remoteIPv4Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, }) if len(test.options) != 0 { ip.SetHeaderLength(uint8(ipHeaderLength)) @@ -304,15 +362,15 @@ func TestForwarding(t *testing.T) { }) e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + reply, ok := e1.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), checker.SrcAddr(ipv4Addr1.Address), - checker.DstAddr(remoteIPv4Addr1), + checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), checker.ICMPv4( checker.ICMPv4Checksum(), @@ -325,15 +383,19 @@ func TestForwarding(t *testing.T) { if n := e2.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } - } else { - reply, ok := e2.Read() + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } + + reply, ok = e2.Read() + if test.expectPacketForwarded { if !ok { t.Fatal("expected ICMP Echo packet 
through outgoing NIC") } checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv4Addr1), - checker.DstAddr(remoteIPv4Addr2), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), checker.IPv4Options(test.forwardedOptions), checker.ICMPv4( @@ -347,6 +409,39 @@ func TestForwarding(t *testing.T) { if n := e1.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } + } else if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + } + + boolToInt := func(val bool) uint64 { + if val { + return 1 + } + return 0 + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want { + t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) } }) } @@ -1241,7 +1336,6 @@ type 
fragmentInfo struct { var fragmentationTests = []struct { description string mtu uint32 - gso *stack.GSO transportHeaderLength int payloadSize int wantFragments []fragmentInfo @@ -1249,7 +1343,6 @@ var fragmentationTests = []struct { { description: "No fragmentation", mtu: 1280, - gso: nil, transportHeaderLength: 0, payloadSize: 1000, wantFragments: []fragmentInfo{ @@ -1259,7 +1352,6 @@ var fragmentationTests = []struct { { description: "Fragmented", mtu: 1280, - gso: nil, transportHeaderLength: 0, payloadSize: 2000, wantFragments: []fragmentInfo{ @@ -1270,7 +1362,6 @@ var fragmentationTests = []struct { { description: "Fragmented with the minimum mtu", mtu: header.IPv4MinimumMTU, - gso: nil, transportHeaderLength: 0, payloadSize: 100, wantFragments: []fragmentInfo{ @@ -1282,7 +1373,6 @@ var fragmentationTests = []struct { { description: "Fragmented with mtu not a multiple of 8", mtu: header.IPv4MinimumMTU + 1, - gso: nil, transportHeaderLength: 0, payloadSize: 100, wantFragments: []fragmentInfo{ @@ -1294,7 +1384,6 @@ var fragmentationTests = []struct { { description: "No fragmentation with big header", mtu: 2000, - gso: nil, transportHeaderLength: 100, payloadSize: 1000, wantFragments: []fragmentInfo{ @@ -1302,20 +1391,8 @@ var fragmentationTests = []struct { }, }, { - description: "Fragmented with gso none", - mtu: 1280, - gso: &stack.GSO{Type: stack.GSONone}, - transportHeaderLength: 0, - payloadSize: 1400, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 144, more: false}, - }, - }, - { description: "Fragmented with big header", mtu: 1280, - gso: nil, transportHeaderLength: 100, payloadSize: 1200, wantFragments: []fragmentInfo{ @@ -1326,7 +1403,6 @@ var fragmentationTests = []struct { { description: "Fragmented with MTU smaller than header", mtu: 300, - gso: nil, transportHeaderLength: 1000, payloadSize: 500, wantFragments: []fragmentInfo{ @@ -1349,13 +1425,13 @@ func TestFragmentationWritePacket(t 
*testing.T) { r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() - err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ + err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Fatalf("r.WritePacket(_, _, _) = %s", err) + t.Fatalf("r.WritePacket(...): %s", err) } if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) @@ -1421,7 +1497,7 @@ func TestFragmentationWritePackets(t *testing.T) { r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -1528,7 +1604,7 @@ func TestFragmentationErrors(t *testing.T) { pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) - err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -2612,34 +2688,36 @@ func TestWriteStats(t *testing.T) { const nPackets = 3 tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectOutputDropped int + expectPostroutingDropped int + expectWritten int }{ { name: "Accept all", // No setup needed, tables accept everything by default. 
- setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { name: "Accept all with error", // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets - 1, }, { - name: "Drop all", + name: "Drop all with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] @@ -2648,16 +2726,32 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %s", err) } }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: nPackets, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { - name: "Drop some", + name: "Drop all with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: 0, + expectPostroutingDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some 
with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule that matches only 1 // of the 3 packets. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) // We'll match and DROP the last packet. @@ -2670,10 +2764,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %s", err) } }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 1, + expectPostroutingDropped: 0, + expectWritten: nPackets, + }, { + name: "Drop some with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Postrouting DROP rule that matches only 1 + // of the 3 packets. + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. 
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 1, + expectWritten: nPackets, }, } @@ -2687,7 +2804,7 @@ func TestWriteStats(t *testing.T) { writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil { return nWritten, err } nWritten++ @@ -2697,7 +2814,7 @@ func TestWriteStats(t *testing.T) { }, { name: "WritePackets", writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + return rt.WritePackets(pkts, stack.NetworkHeaderParams{}) }, }, } @@ -2724,13 +2841,16 @@ func TestWriteStats(t *testing.T) { nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { + t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) + } + if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { + t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, 
test.expectPostroutingDropped) } if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) } }) } @@ -2995,12 +3115,14 @@ func TestCloseLocking(t *testing.T) { nicID1 = 1 nicID2 = 2 - src = tcpip.Address("\x10\x00\x00\x01") - dst = tcpip.Address("\x10\x00\x00\x02") - iterations = 1000 ) + var ( + src = tcptestutil.MustParse4("16.0.0.1") + dst = tcptestutil.MustParse4("16.0.0.2") + ) + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index bb9a02ed0..db998e83e 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -66,5 +66,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index a142b76c1..ebb0b73df 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe return } + // Keep needed information before trimming header. + p := hdr.TransportProtocol() + dstAddr := hdr.DestinationAddress() + // Skip the IP header, then handle the fragmentation header if there // is one. - pkt.Data().TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6MinimumSize) if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // because they don't have the transport headers. 
return } + p = fragHdr.TransportProtocol() // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) } - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -273,7 +276,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if iph.HopLimit() != header.MLDHopLimit { return false } - if !header.IsV6LinkLocalAddress(iph.SourceAddress()) { + if !header.IsV6LinkLocalUnicastAddress(iph.SourceAddress()) { return false } return true @@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize) networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 } + pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: @@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + code := header.ICMPv6(hdr).Code() + pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) + switch code { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: @@ -564,7 +568,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // The IP Hop Limit field has a value of 255, i.e., the packet 
// could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.dropped.Increment() return } @@ -704,7 +708,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, @@ -804,7 +808,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r routerAddr := srcAddr // Is the IP Source Address a link-local address? - if !header.IsV6LinkLocalAddress(routerAddr) { + if !header.IsV6LinkLocalUnicastAddress(routerAddr) { // ...No, silently drop the packet. received.invalid.Increment() return @@ -951,6 +955,7 @@ func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo // icmpReason is a marker interface for IPv6 specific ICMP errors. type icmpReason interface { isICMPReason() + isForwarding() bool } // icmpReasonParameterProblem is an error during processing of extension headers @@ -982,6 +987,9 @@ type icmpReasonParameterProblem struct { } func (*icmpReasonParameterProblem) isICMPReason() {} +func (*icmpReasonParameterProblem) isForwarding() bool { + return false +} // icmpReasonPortUnreachable is an error where the transport protocol has no // listener and no alternative means to inform the sender. 
@@ -989,12 +997,44 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +func (*icmpReasonPortUnreachable) isForwarding() bool { + return false +} + +// icmpReasonNetUnreachable is an error where no route can be found to the +// network of the final destination. +type icmpReasonNetUnreachable struct{} + +func (*icmpReasonNetUnreachable) isICMPReason() {} + +func (*icmpReasonNetUnreachable) isForwarding() bool { + // If we hit a Network Unreachable error, then we also know we are + // operating as a router. As per RFC 4443 section 3.1: + // + // If the reason for the failure to deliver is lack of a matching + // entry in the forwarding node's routing table, the Code field is + // set to 0 (Network Unreachable). + return true +} + // icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in // transit to its final destination, as per RFC 4443 section 3.3. type icmpReasonHopLimitExceeded struct{} func (*icmpReasonHopLimitExceeded) isICMPReason() {} +func (*icmpReasonHopLimitExceeded) isForwarding() bool { + // If we hit a Hop Limit Exceeded error, then we know we are operating + // as a router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard + // the packet and originate an ICMPv6 Time Exceeded message with Code + // 0 to the source of the packet. This indicates either a routing + // loop or too small an initial Hop Limit value. + return true +} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. 
@@ -1002,6 +1042,10 @@ type icmpReasonReassemblyTimeout struct{} func (*icmpReasonReassemblyTimeout) isICMPReason() {} +func (*icmpReasonReassemblyTimeout) isForwarding() bool { + return false +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { @@ -1040,15 +1084,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return nil } - // If we hit a Hop Limit Exceeded error, then we know we are operating as a - // router. As per RFC 4443 section 3.3: - // - // If a router receives a packet with a Hop Limit of zero, or if a - // router decrements a packet's Hop Limit to zero, it MUST discard the - // packet and originate an ICMPv6 Time Exceeded message with Code 0 to - // the source of the packet. This indicates either a routing loop or - // too small an initial Hop Limit value. - // // If we are operating as a router, do not use the packet's destination // address as the response's source address as we should not own the // destination address of a packet we are forwarding. @@ -1058,7 +1093,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // packet as "multicast addresses must not be used as source addresses in IPv6 // packets", as per RFC 4291 section 2.7. 
localAddr := origIPHdrDst - if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { + if reason.isForwarding() || isOrigDstMulticast { localAddr = "" } // Even if we were able to receive a packet from some remote, we may not have @@ -1147,6 +1182,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) + counter = sent.dstUnreachable case *icmpReasonHopLimitExceeded: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) @@ -1167,7 +1206,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip PayloadLen: dataRange.Size(), })) if err := route.WritePacket( - nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: route.DefaultTTL(), diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 6a7705ed1..e457be3cf 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -81,7 +81,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { +func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { return nil } @@ -130,19 +130,19 @@ func (*testInterface) Spoofing() bool { return false } -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) +func (t *testInterface) WritePacket(r *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt 
*stack.PacketBuffer) tcpip.Error { + return t.LinkEndpoint.WritePacket(r.Fields(), protocol, pkt) } -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) +func (t *testInterface) WritePackets(r *stack.Route, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return t.LinkEndpoint.WritePackets(r.Fields(), pkts, protocol) } -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var r stack.RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr - return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) + return t.LinkEndpoint.WritePacket(r, protocol, pkt) } func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c6d9d8f0d..659057fa7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -314,7 +314,7 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { // Snooping switches MUST manage multicast forwarding state based on MLD // Report and Done messages sent with the unspecified address as the // IPv6 source address. - if header.IsV6LinkLocalAddress(addr) { + if header.IsV6LinkLocalUnicastAddress(addr) { e.mu.mld.sendQueuedReports() } } @@ -410,24 +410,52 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t // // Must only be called when the forwarding status changes. 
func (e *endpoint) transitionForwarding(forwarding bool) { + allRoutersGroups := [...]tcpip.Address{ + header.IPv6AllRoutersInterfaceLocalMulticastAddress, + header.IPv6AllRoutersLinkLocalMulticastAddress, + header.IPv6AllRoutersSiteLocalMulticastAddress, + } + e.mu.Lock() defer e.mu.Unlock() - if !e.Enabled() { - return - } - if forwarding { - // When transitioning into an IPv6 router, host-only state (NDP discovered - // routers, discovered on-link prefixes, and auto-generated addresses) is - // cleaned up/invalidated and NDP router solicitations are stopped. - e.mu.ndp.stopSolicitingRouters() - e.mu.ndp.cleanupState(true /* hostOnly */) + // As per RFC 4291 section 2.8: + // + // A router is required to recognize all addresses that a host is + // required to recognize, plus the following addresses as identifying + // itself: + // + // o The All-Routers multicast addresses defined in Section 2.7.1. + // + // As per RFC 4291 section 2.7.1, + // + // All Routers Addresses: FF01:0:0:0:0:0:0:2 + // FF02:0:0:0:0:0:0:2 + // FF05:0:0:0:0:0:0:2 + // + // The above multicast addresses identify the group of all IPv6 routers, + // within scope 1 (interface-local), 2 (link-local), or 5 (site-local). + for _, g := range allRoutersGroups { + if err := e.joinGroupLocked(g); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err)) + } + } } else { - // When transitioning into an IPv6 host, NDP router solicitations are - // started. - e.mu.ndp.startSolicitingRouters() + for _, g := range allRoutersGroups { + switch err := e.leaveGroupLocked(g).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } + } } + + e.mu.ndp.forwardingChanged(forwarding) } // Enable implements stack.NetworkEndpoint. 
@@ -509,17 +537,7 @@ func (e *endpoint) Enable() tcpip.Error { e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) } - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyway. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. - if !e.protocol.Forwarding() { - e.mu.ndp.startSolicitingRouters() - } - + e.mu.ndp.startSolicitingRouters() return nil } @@ -570,10 +588,10 @@ func (e *endpoint) disableLocked() { return true }) - e.mu.ndp.cleanupState(false /* hostOnly */) + e.mu.ndp.cleanupState() // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { + switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) @@ -632,9 +650,9 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params return nil } -func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool { payload := pkt.TransportHeader().View().Size() + pkt.Data().Size() - return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU + return pkt.GSOOptions.Type == stack.GSONone && uint32(payload) > networkMTU } // handleFragments fragments pkt and calls the handler function on each @@ -642,7 +660,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta // fragments left to be processed. The IP header must already be present in the // original packet. 
The transport header protocol number is required to avoid // parsing the IPv6 extension headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { networkHeader := header.IPv6(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are @@ -681,7 +699,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { if err := addIPHeader(r.LocalAddress(), r.RemoteAddress(), pkt, params, nil /* extensionHeaders */); err != nil { return err } @@ -689,7 +707,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. 
e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -712,10 +730,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } } - return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) + return e.writePacket(r, pkt, params.Protocol, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop()&stack.PacketLoop != 0 { // If the packet was generated by the stack (not a raw/packet endpoint // where a packet may be written with the header included), then we can @@ -726,6 +744,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + // Postrouting NAT can only change the source address, and does not alter the + // route or outgoing interface of the packet. + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + // iptables is telling us to drop the packet. 
+ e.stats.ip.IPTablesPostroutingDropped.Increment() + return nil + } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { @@ -733,20 +760,20 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return err } - if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { + if packetMustBeFragmented(pkt, networkMTU) { + sent, remain, err := e.handleFragments(r, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. - return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + return e.nic.WritePacket(r, ProtocolNumber, fragPkt) }) stats.PacketsSent.IncrementBy(uint64(sent)) stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, ProtocolNumber, pkt); err != nil { stats.OutgoingPacketErrors.Increment() return err } @@ -756,7 +783,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet } // WritePackets implements stack.NetworkEndpoint.WritePackets. 
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("not implemented") } @@ -776,11 +803,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } - if packetMustBeFragmented(pb, networkMTU, gso) { + if packetMustBeFragmented(pb, networkMTU) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { + if _, _, err := e.handleFragments(r, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt @@ -797,9 +824,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. 
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) - stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - for pkt := range dropped { + outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) + for pkt := range outputDropped { pkts.Remove(pkt) } @@ -820,14 +847,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe locallyDelivered++ } + // We ignore the list of NAT-ed packets here because Postrouting NAT can only + // change the source address, and does not alter the route or outgoing + // interface of the packet. + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) + for pkt := range postroutingDropped { + pkts.Remove(pkt) + } + // The rest of the packets can be delivered to the NIC as a batch. pktsLen := pkts.Len() - written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + written, err := e.nic.WritePackets(r, pkts, ProtocolNumber) stats.PacketsSent.IncrementBy(uint64(written)) stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) // Dropped packets aren't errors, so include them in the return value. - return locallyDelivered + written + len(dropped), err + return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. 
@@ -863,12 +899,25 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return &tcpip.ErrMalformedHeader{} } - return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) + return e.writePacket(r, pkt, proto, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv6(pkt.NetworkHeader().View()) + + dstAddr := h.DestinationAddress() + // As per RFC 4291 section 2.5.6, + // + // Routers must not forward any packets with Link-Local source or + // destination addresses to other links. + if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} + } + hopLimit := h.HopLimit() if hopLimit <= 1 { // As per RFC 4443 section 3.3, @@ -878,11 +927,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // packet and originate an ICMPv6 Time Exceeded message with Code 0 to // the source of the packet. This indicates either a routing loop or // too small an initial Hop Limit value. - return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } - dstAddr := h.DestinationAddress() - // Check if the destination is owned by the stack. 
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { ep.handleValidatedPacket(h, pkt) @@ -890,8 +942,16 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonNetUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() @@ -906,10 +966,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + })); err != nil { + return &ip.ErrOther{Err: err} + } + return nil } // HandlePacket is called by the link layer when new ipv6 packets arrive for @@ -958,7 +1021,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. 
stats.IPTablesPreroutingDropped.Increment() return @@ -1010,8 +1073,21 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) stats.InvalidDestinationAddressesReceived.Increment() return } - - _ = e.forwardPacket(pkt) + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + e.stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + e.stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + e.stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + e.stats.ip.Forwarding.Unrouteable.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + e.stats.ip.Forwarding.Errors.Increment() return } @@ -1028,7 +1104,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. 
stats.IPTablesInputDropped.Increment() return @@ -1571,7 +1647,7 @@ func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { var linkLocalAddr tcpip.Address e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { if addressEndpoint.IsAssigned(false /* allowExpired */) { - if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalUnicastAddress(addr) { linkLocalAddr = addr return false } @@ -1979,9 +2055,9 @@ func (p *protocol) Forwarding() bool { // Returns true if the forwarding status was updated. func (p *protocol) setForwarding(v bool) bool { if v { - return atomic.SwapUint32(&p.forwarding, 1) == 0 + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) } - return atomic.SwapUint32(&p.forwarding, 0) == 1 + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) } // SetForwarding implements stack.ForwardingNetworkProtocol. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index c206cebeb..4fbe39528 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -2468,34 +2468,36 @@ func TestFragmentReassemblyTimeout(t *testing.T) { func TestWriteStats(t *testing.T) { const nPackets = 3 tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectOutputDropped int + expectPostroutingDropped int + expectWritten int }{ { name: "Accept all", // No setup needed, tables accept everything by default. 
- setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { name: "Accept all with error", // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets - 1, }, { - name: "Drop all", + name: "Drop all with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] @@ -2504,16 +2506,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: nPackets, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { - name: "Drop some", + name: "Drop all with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. 
+ ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: 0, + expectPostroutingDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule that matches only 1 // of the 3 packets. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) // We'll match and DROP the last packet. @@ -2526,10 +2545,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 1, + expectPostroutingDropped: 0, + expectWritten: nPackets, + }, { + name: "Drop some with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Postrouting DROP rule that matches only 1 + // of the 3 packets. + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. 
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 1, + expectWritten: nPackets, }, } @@ -2542,7 +2584,7 @@ func TestWriteStats(t *testing.T) { writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil { return nWritten, err } nWritten++ @@ -2552,7 +2594,7 @@ func TestWriteStats(t *testing.T) { }, { name: "WritePackets", writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + return rt.WritePackets(pkts, stack.NetworkHeaderParams{}) }, }, } @@ -2578,13 +2620,16 @@ func TestWriteStats(t *testing.T) { nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { + t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) + } + if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { + t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, 
test.expectPostroutingDropped) } if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) } }) } @@ -2694,7 +2739,6 @@ type fragmentInfo struct { var fragmentationTests = []struct { description string mtu uint32 - gso *stack.GSO transHdrLen int payloadSize int wantFragments []fragmentInfo @@ -2702,7 +2746,6 @@ var fragmentationTests = []struct { { description: "No fragmentation", mtu: header.IPv6MinimumMTU, - gso: nil, transHdrLen: 0, payloadSize: 1000, wantFragments: []fragmentInfo{ @@ -2712,7 +2755,6 @@ var fragmentationTests = []struct { { description: "Fragmented", mtu: header.IPv6MinimumMTU, - gso: nil, transHdrLen: 0, payloadSize: 2000, wantFragments: []fragmentInfo{ @@ -2723,7 +2765,6 @@ var fragmentationTests = []struct { { description: "Fragmented with mtu not a multiple of 8", mtu: header.IPv6MinimumMTU + 1, - gso: nil, transHdrLen: 0, payloadSize: 2000, wantFragments: []fragmentInfo{ @@ -2734,7 +2775,6 @@ var fragmentationTests = []struct { { description: "No fragmentation with big header", mtu: 2000, - gso: nil, transHdrLen: 100, payloadSize: 1000, wantFragments: []fragmentInfo{ @@ -2742,20 +2782,8 @@ var fragmentationTests = []struct { }, }, { - description: "Fragmented with gso none", - mtu: header.IPv6MinimumMTU, - gso: &stack.GSO{Type: stack.GSONone}, - transHdrLen: 0, - payloadSize: 1400, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 176, more: false}, - }, - }, - { description: "Fragmented with big header", mtu: header.IPv6MinimumMTU, - gso: nil, transHdrLen: 100, payloadSize: 1200, wantFragments: []fragmentInfo{ @@ -2778,7 +2806,7 @@ func TestFragmentationWritePacket(t *testing.T) { source := pkt.Clone() ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - err := r.WritePacket(ft.gso, 
stack.NetworkHeaderParams{ + err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -2851,7 +2879,7 @@ func TestFragmentationWritePackets(t *testing.T) { r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -2955,7 +2983,7 @@ func TestFragmentationErrors(t *testing.T) { pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) - err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -2991,36 +3019,94 @@ func TestForwarding(t *testing.T) { } remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16()) + multicastIPv6Addr := tcpip.Address(net.ParseIP("ff00::").To16()) + linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16()) tests := []struct { - name string - TTL uint8 - expectErrorICMP bool + name string + TTL uint8 + expectErrorICMP bool + expectPacketForwarded bool + countUnrouteablePackets uint64 + sourceAddr tcpip.Address + destAddr tcpip.Address + icmpType header.ICMPv6Type + icmpCode header.ICMPv6Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool }{ { name: "TTL of zero", TTL: 0, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, }, { name: "TTL of one", TTL: 1, 
expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, + }, + { + name: "TTL of two", + TTL: 2, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "TTL of three", + TTL: 3, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "Max TTL", + TTL: math.MaxUint8, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, }, { - name: "TTL of three", - TTL: 3, - expectErrorICMP: false, + name: "Network unreachable", + TTL: 2, + expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: unreachableIPv6Addr, + icmpType: header.ICMPv6DstUnreachable, + icmpCode: header.ICMPv6NetworkUnreachable, + expectPacketUnrouteableError: true, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Multicast destination", + TTL: 2, + countUnrouteablePackets: 1, + sourceAddr: remoteIPv6Addr1, + destAddr: multicastIPv6Addr, + expectPacketUnrouteableError: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: linkLocalIPv6Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv6Addr, + destAddr: remoteIPv6Addr2, + expectLinkLocalSourceError: true, }, } @@ -3073,35 +3159,35 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmp, - Src: remoteIPv6Addr1, - Dst: remoteIPv6Addr2, + Src: test.sourceAddr, + Dst: test.destAddr, })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, TransportProtocol: header.ICMPv6ProtocolNumber, HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, 
+ SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, }) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) e1.InjectInbound(ProtocolNumber, requestPkt) + reply, ok := e1.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { - t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), checker.SrcAddr(ipv6Addr1.Address), - checker.DstAddr(remoteIPv6Addr1), + checker.DstAddr(test.sourceAddr), checker.TTL(DefaultTTL), checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), + checker.ICMPv6Type(test.icmpType), + checker.ICMPv6Code(test.icmpCode), checker.ICMPv6Payload([]byte(hdr.View())), ), ) @@ -3109,15 +3195,19 @@ func TestForwarding(t *testing.T) { if n := e2.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } - } else { - reply, ok := e2.Read() + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } + + reply, ok = e2.Read() + if test.expectPacketForwarded { if !ok { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv6Addr1), - checker.DstAddr(remoteIPv6Addr2), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6EchoRequest), @@ -3129,6 +3219,35 @@ func TestForwarding(t *testing.T) { if n := e1.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } + } else if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + } + + boolToInt := func(val bool) uint64 { + if val { + return 1 + } + return 0 + } + + if got, want := 
s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want { + t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) } }) } diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index dd153466d..bc1af193c 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -76,10 +76,29 @@ func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) // // Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error { - _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) return err } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + // As per RFC 2710 section 5 page 10, + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. 
The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + // + // MLD messages are never sent for multicast addresses whose scope is 0 + // (reserved) or 1 (node-local). + if groupAddress == header.IPv6AllNodesMulticastAddress { + return false + } + + scope := header.V6MulticastScope(groupAddress) + return scope != header.IPv6Reserved0MulticastScope && scope != header.IPv6InterfaceLocalMulticastScope +} + // init sets up an mldState struct, and is required to be called before using // a new mldState. // @@ -91,7 +110,6 @@ func (mld *mldState) init(ep *endpoint) { Clock: ep.protocol.stack.Clock(), Protocol: mld, MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, - AllNodesAddress: header.IPv6AllNodesMulticastAddress, }) } @@ -259,7 +277,7 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp }, extensionHeaders); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } - if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), ProtocolNumber, pkt); err != nil { sentStats.dropped.Increment() return false, err } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 85a8f9944..71d1c3e28 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -27,15 +27,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalMulticastAddr = 
"\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var ( + linkLocalAddr = testutil.MustParse6("fe80::1") + globalAddr = testutil.MustParse6("a80::1") + globalMulticastAddr = testutil.MustParse6("ff05:100::2") + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) ) @@ -93,7 +92,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { if p, ok := e.Read(); !ok { t.Fatal("expected a done message to be sent") } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) } } @@ -354,10 +353,8 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, ho } func TestMLDPacketValidation(t *testing.T) { - const ( - nicID = 1 - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ) + const nicID = 1 + linkLocalAddr2 := testutil.MustParse6("fe80::2") tests := []struct { name string @@ -464,3 +461,141 @@ func TestMLDPacketValidation(t *testing.T) { }) } } + +func TestMLDSkipProtocol(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + group tcpip.Address + expectReport bool + }{ + { + name: "Reserverd0", + group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: false, + }, + { + name: "Interface Local", + group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: false, + }, + { + name: "Link Local", + group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Realm Local", + group: 
"\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Admin Local", + group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Site Local", + group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(6)", + group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(7)", + group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Organization Local", + group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(9)", + group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(A)", + group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(B)", + group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(C)", + group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(D)", + group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Global", + group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "ReservedF", + group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err 
:= s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil { + t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err) + } + if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group) + } + + if !test.expectReport { + if p, ok := e.Read(); ok { + t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p) + } + + return + } + + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 536493f87..b29fed347 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -48,7 +48,7 @@ const ( // defaultHandleRAs is the default configuration for whether or not to // handle incoming Router Advertisements as a host. 
- defaultHandleRAs = true + defaultHandleRAs = HandlingRAsEnabledWhenForwardingDisabled // defaultDiscoverDefaultRouters is the default configuration for // whether or not to discover default routers from incoming Router @@ -301,10 +301,60 @@ type NDPDispatcher interface { OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } +var _ fmt.Stringer = HandleRAsConfiguration(0) + +// HandleRAsConfiguration enumerates when RAs may be handled. +type HandleRAsConfiguration int + +const ( + // HandlingRAsDisabled indicates that Router Advertisements will not be + // handled. + HandlingRAsDisabled HandleRAsConfiguration = iota + + // HandlingRAsEnabledWhenForwardingDisabled indicates that router + // advertisements will only be handled when forwarding is disabled. + HandlingRAsEnabledWhenForwardingDisabled + + // HandlingRAsAlwaysEnabled indicates that Router Advertisements will always + // be handled, even when forwarding is enabled. + HandlingRAsAlwaysEnabled +) + +// String implements fmt.Stringer. +func (c HandleRAsConfiguration) String() string { + switch c { + case HandlingRAsDisabled: + return "HandlingRAsDisabled" + case HandlingRAsEnabledWhenForwardingDisabled: + return "HandlingRAsEnabledWhenForwardingDisabled" + case HandlingRAsAlwaysEnabled: + return "HandlingRAsAlwaysEnabled" + default: + return fmt.Sprintf("HandleRAsConfiguration(%d)", c) + } +} + +// enabled returns true iff Router Advertisements may be handled given the +// specified forwarding status. +func (c HandleRAsConfiguration) enabled(forwarding bool) bool { + switch c { + case HandlingRAsDisabled: + return false + case HandlingRAsEnabledWhenForwardingDisabled: + return !forwarding + case HandlingRAsAlwaysEnabled: + return true + default: + panic(fmt.Sprintf("unhandled HandleRAsConfiguration = %d", c)) + } +} + // NDPConfigurations is the NDP configurations for the netstack. 
type NDPConfigurations struct { // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. + // + // Ignored unless configured to handle Router Advertisements. MaxRtrSolicitations uint8 // The amount of time between transmitting Router Solicitation messages. @@ -318,8 +368,9 @@ type NDPConfigurations struct { // Must be greater than or equal to 0s. MaxRtrSolicitationDelay time.Duration - // HandleRAs determines whether or not Router Advertisements are processed. - HandleRAs bool + // HandleRAs is the configuration for when Router Advertisements should be + // handled. + HandleRAs HandleRAsConfiguration // DiscoverDefaultRouters determines whether or not default routers are // discovered from Router Advertisements, as per RFC 4861 section 6. This @@ -654,7 +705,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a protocol-wide configuration, so we check the // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding // packets. - if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { + if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) { + ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment() return } @@ -737,7 +789,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { prefix := opt.Subnet() // Is the prefix a link-local? - if header.IsV6LinkLocalAddress(prefix.ID()) { + if header.IsV6LinkLocalUnicastAddress(prefix.ID()) { // ...Yes, skip as per RFC 4861 section 6.3.4, // and RFC 4862 section 5.5.3.b (for SLAAC). continue @@ -1609,44 +1661,16 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs map[t delete(tempAddrs, tempAddr) } -// removeSLAACAddresses removes all SLAAC addresses. -// -// If keepLinkLocal is false, the SLAAC generated link-local address is removed. -// -// The IPv6 endpoint that ndp belongs to MUST be locked. 
-func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { - linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - var linkLocalPrefixes int - for prefix, state := range ndp.slaacPrefixes { - // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them if we are cleaning up - // host-only state. - if keepLinkLocal && prefix == linkLocalSubnet { - linkLocalPrefixes++ - continue - } - - ndp.invalidateSLAACPrefix(prefix, state) - } - - if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { - panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) - } -} - // cleanupState cleans up ndp's state. // -// If hostOnly is true, then only host-specific state is cleaned up. -// // This function invalidates all discovered on-link prefixes, discovered // routers, and auto-generated addresses. // -// If hostOnly is true, then the link-local auto-generated address aren't -// invalidated as routers are also expected to generate a link-local address. -// // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { - ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) +func (ndp *ndpState) cleanupState() { + for prefix, state := range ndp.slaacPrefixes { + ndp.invalidateSLAACPrefix(prefix, state) + } for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) @@ -1670,6 +1694,10 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // startSolicitingRouters starts soliciting routers, as per RFC 4861 section // 6.3.7. If routers are already being solicited, this function does nothing. // +// If ndp is not configured to handle Router Advertisements, routers will not +// be solicited as there is no point soliciting routers if we don't handle their +// advertisements. +// // The IPv6 endpoint that ndp belongs to MUST be locked. 
func (ndp *ndpState) startSolicitingRouters() { if ndp.rtrSolicitTimer.timer != nil { @@ -1682,6 +1710,10 @@ func (ndp *ndpState) startSolicitingRouters() { return } + if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) { + return + } + // Calculate the random delay before sending our first RS, as per RFC // 4861 section 6.3.7. var delay time.Duration @@ -1703,7 +1735,7 @@ func (ndp *ndpState) startSolicitingRouters() { // the unspecified address if no address is assigned // to the sending interface. localAddr := header.IPv6Any - if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersLinkLocalMulticastAddress, false); addressEndpoint != nil { localAddr = addressEndpoint.AddressWithPrefix().Address addressEndpoint.DecRef() } @@ -1730,7 +1762,7 @@ func (ndp *ndpState) startSolicitingRouters() { icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpData, Src: localAddr, - Dst: header.IPv6AllRoutersMulticastAddress, + Dst: header.IPv6AllRoutersLinkLocalMulticastAddress, })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1739,14 +1771,14 @@ func (ndp *ndpState) startSolicitingRouters() { }) sent := ndp.ep.stats.icmp.packetsSent - if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } - if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := 
ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress), ProtocolNumber, pkt); err != nil { sent.dropped.Increment() // Don't send any more messages if we had an error. remaining = 0 @@ -1774,6 +1806,32 @@ func (ndp *ndpState) startSolicitingRouters() { } } +// forwardingChanged handles a change in forwarding configuration. +// +// If transitioning to a host, router solicitation will be started. Otherwise, +// router solicitation will be stopped if NDP is not configured to handle RAs +// as a router. +// +// Precondition: ndp.ep.mu must be locked. +func (ndp *ndpState) forwardingChanged(forwarding bool) { + if forwarding { + if ndp.configs.HandleRAs.enabled(forwarding) { + return + } + + ndp.stopSolicitingRouters() + return + } + + // Solicit routers when transitioning to a host. + // + // If the endpoint is not currently enabled, routers will be solicited when + // the endpoint becomes enabled (if it is still a host). + if ndp.ep.Enabled() { + ndp.startSolicitingRouters() + } +} + // stopSolicitingRouters stops soliciting routers. If routers are not currently // being solicited, this function does nothing. // @@ -1839,7 +1897,7 @@ func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteL } sent := e.stats.icmp.packetsSent - err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) + err := e.nic.WritePacketToRemote(remoteLinkAddr, ProtocolNumber, pkt) if err != nil { sent.dropped.Increment() } else { diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go index c2758352f..2f18f60e8 100644 --- a/pkg/tcpip/network/ipv6/stats.go +++ b/pkg/tcpip/network/ipv6/stats.go @@ -29,6 +29,10 @@ type Stats struct { // ICMP holds ICMPv6 statistics. ICMP tcpip.ICMPv6Stats + + // UnhandledRouterAdvertisements is the number of Router Advertisements that + // were observed but not handled. 
+ UnhandledRouterAdvertisements *tcpip.StatCounter } // IsNetworkEndpointStats implements stack.NetworkEndpointStats. diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index ecd5003a7..1b96b1fb8 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -30,22 +30,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - stackIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") defaultIPv4PrefixLength = 24 - linkLocalIPv6Addr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalIPv6Addr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - - ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") - ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") - ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") - ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") - ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") igmpMembershipQuery = uint8(header.IGMPMembershipQuery) igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) @@ -59,6 +50,19 @@ const ( ) var ( + stackIPv4Addr = testutil.MustParse4("10.0.0.1") + linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1") + linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2") + + ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3") + ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4") + ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5") + ipv6MulticastAddr1 = testutil.MustParse6("ff02::3") + ipv6MulticastAddr2 = testutil.MustParse6("ff02::4") + ipv6MulticastAddr3 = 
testutil.MustParse6("ff02::5") +) + +var ( // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the // NIC will wait before sending an unsolicited report after joining a // multicast group, in deciseconds. @@ -194,7 +198,7 @@ func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, c if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") } else { - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC) } // Should not send any more packets. @@ -606,7 +610,7 @@ func TestMGPLeaveGroup(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, checkInitialGroups: checkInitialIPv6Groups, }, @@ -1014,7 +1018,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr) }, getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { t.Helper() diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 210262703..b7f6d52ae 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -21,6 +21,7 @@ go_test( library = ":ports", deps = [ "//pkg/tcpip", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 678199371..b5b013b64 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -17,6 +17,7 @@ package ports import ( + "math" 
"math/rand" "sync/atomic" @@ -24,7 +25,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -const anyIPAddress tcpip.Address = "" +const ( + firstEphemeral = 16000 + anyIPAddress tcpip.Address = "" +) // Reservation describes a port reservation. type Reservation struct { @@ -220,10 +224,8 @@ type PortManager struct { func NewPortManager() *PortManager { return &PortManager{ allocatedPorts: make(map[portDescriptor]addrToDevice), - // Match Linux's default ephemeral range. See: - // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842 - firstEphemeral: 32768, - numEphemeral: 28232, + firstEphemeral: firstEphemeral, + numEphemeral: math.MaxUint16 - firstEphemeral + 1, } } @@ -242,13 +244,13 @@ func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err numEphemeral := pm.numEphemeral pm.ephemeralMu.RUnlock() - offset := uint16(rand.Int31n(int32(numEphemeral))) + offset := uint32(rand.Int31n(int32(numEphemeral))) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } // portHint atomically reads and returns the pm.hint value. -func (pm *PortManager) portHint() uint16 { - return uint16(atomic.LoadUint32(&pm.hint)) +func (pm *PortManager) portHint() uint32 { + return atomic.LoadUint32(&pm.hint) } // incPortHint atomically increments pm.hint by 1. @@ -260,7 +262,7 @@ func (pm *PortManager) incPortHint() { // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. 
-func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) { +func (pm *PortManager) PickEphemeralPortStable(offset uint32, testPort PortTester) (port uint16, err tcpip.Error) { pm.ephemeralMu.RLock() firstEphemeral := pm.firstEphemeral numEphemeral := pm.numEphemeral @@ -277,9 +279,9 @@ func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTeste // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { - for i := uint16(0); i < count; i++ { - port = first + (offset+i)%count +func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { + for i := uint32(0); i < uint32(count); i++ { + port := uint16(uint32(first) + (offset+i)%uint32(count)) ok, err := testPort(port) if err != nil { return 0, err diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 0f43dc8f8..6c4fb8c68 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -15,19 +15,23 @@ package ports import ( + "math" "math/rand" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 +) - fakeIPAddress = tcpip.Address("\x08\x08\x08\x08") - fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") +var ( + fakeIPAddress = testutil.MustParse4("8.8.8.8") + fakeIPAddress1 = testutil.MustParse4("8.8.8.9") ) type portReserveTestAction struct { @@ -479,7 +483,7 @@ func TestPickEphemeralPortStable(t *testing.T) { if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { t.Fatalf("failed to set ephemeral port 
range: %s", err) } - portOffset := uint16(rand.Int31n(int32(numEphemeralPorts))) + portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) port, err := pm.PickEphemeralPortStable(portOffset, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) @@ -490,3 +494,29 @@ func TestPickEphemeralPortStable(t *testing.T) { }) } } + +// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a +// port allocation failure. +func TestOverflow(t *testing.T) { + // Use a small range and start at offsets that will cause an overflow. + count := uint16(50) + for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ { + reservedPorts := make(map[uint16]struct{}) + // Ensure we can reserve everything in the allowed range. + for i := uint16(0); i < count; i++ { + port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) { + if _, ok := reservedPorts[port]; !ok { + reservedPorts[port] = struct{}{} + return true, nil + } + return false, nil + }) + if err != nil { + t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts)) + } + if port < firstEphemeral || port > firstEphemeral+count { + t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1) + } + } + } +} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index dc37e61a4..a6c877158 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -58,6 +58,9 @@ type SocketOptionsHandler interface { // changed. The handler is invoked with the new value for the socket send // buffer size. It also returns the newly set value. OnSetSendBufferSize(v int64) (newSz int64) + + // OnSetReceiveBufferSize is invoked to set the SO_RCVBUF.
+ OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -99,6 +102,11 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) { return v } +// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize. +func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) { + return v +} + // StackHandler holds methods to access the stack options. These must be // implemented by the stack. type StackHandler interface { @@ -207,6 +215,14 @@ type SocketOptions struct { // sendBufferSize determines the send buffer size for this socket. sendBufferSize int64 + // getReceiveBufferLimits provides the handler to get the min, default and + // max size for receive buffer. It is initialized at the creation time and + // will not change. + getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` + + // receiveBufferSize determines the receive buffer size for this socket. + receiveBufferSize int64 + // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -217,10 +233,11 @@ type SocketOptions struct { // InitHandler initializes the handler. This must be called before using the // socket options utility. 
-func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) { +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits, getReceiveBufferLimits GetReceiveBufferLimits) { so.handler = handler so.stackHandler = stack so.getSendBufferLimits = getSendBufferLimits + so.getReceiveBufferLimits = getReceiveBufferLimits } func storeAtomicBool(addr *uint32, v bool) { @@ -632,3 +649,42 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { newSz := so.handler.OnSetSendBufferSize(v) atomic.StoreInt64(&so.sendBufferSize, newSz) } + +// GetReceiveBufferSize gets value for SO_RCVBUF option. +func (so *SocketOptions) GetReceiveBufferSize() int64 { + return atomic.LoadInt64(&so.receiveBufferSize) +} + +// SetReceiveBufferSize sets value for SO_RCVBUF option. +func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { + if !notify { + atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize) + return + } + + // Make sure the receive buffer size is within the min and max + // allowed. + v := receiveBufferSize + ss := so.getReceiveBufferLimits(so.stackHandler) + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the receive buffer size with min and max values. + if v > max { + v = max + } + + // Multiply it by factor of 2. + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min + } + } else { + v = math.MaxInt32 + } + + oldSz := atomic.LoadInt64(&so.receiveBufferSize) + // Notify endpoint about change in buffer size.
+ newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) + atomic.StoreInt64(&so.receiveBufferSize, newSz) +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 49362333a..2bd6a67f5 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -45,6 +45,7 @@ go_library( "addressable_endpoint_state.go", "conntrack.go", "headertype_string.go", + "hook_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -66,6 +67,7 @@ go_library( "stack.go", "stack_global_state.go", "stack_options.go", + "tcp.go", "transport_demuxer.go", "tuple_list.go", ], @@ -115,6 +117,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -139,6 +142,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 3f083928f..5720e7543 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -16,6 +16,7 @@ package stack import ( "encoding/binary" + "fmt" "sync" "time" @@ -29,7 +30,7 @@ import ( // The connection is created for a packet if it does not exist. Every // connection contains two tuples (original and reply). The tuples are // manipulated if there is a matching NAT rule. The packet is modified by -// looking at the tuples in the Prerouting and Output hooks. +// looking at the tuples in each hook. // // Currently, only TCP tracking is supported. @@ -46,12 +47,14 @@ const ( ) // Manipulation type for the connection. +// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and +// DNAT at the same time. 
type manipType int const ( manipNone manipType = iota - manipDstPrerouting - manipDstOutput + manipSource + manipDestination ) // tuple holds a connection's identifying and manipulating data in one @@ -108,6 +111,7 @@ type conn struct { reply tuple // manip indicates if the packet should be manipulated. It is immutable. + // TODO(gvisor.dev/issue/5696): Support updating manipulation type. manip manipType // tcbHook indicates if the packet is inbound or outbound to @@ -124,6 +128,18 @@ type conn struct { lastUsed time.Time `state:".(unixTime)"` } +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn +} + // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now time.Time) bool { const establishedTimeout = 5 * 24 * time.Hour @@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { }, nil } -// newConn creates new connection. -func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn -} - func (ct *ConnTrack) init() { ct.mu.Lock() defer ct.mu.Unlock() @@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1 return nil } - // Create a new connection and change the port as per the iptables - // rule. This tuple will be used to manipulate the packet in - // handlePacket. 
replyTID := tid.reply() replyTID.srcAddr = address replyTID.srcPort = port - var manip manipType - switch hook { - case Prerouting: - manip = manipDstPrerouting - case Output: - manip = manipDstOutput + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil } - conn := newConn(tid, replyTID, manip, hook) + conn = newConn(tid, replyTID, manipDestination, hook) + ct.insertConn(conn) + return conn +} + +func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Input && hook != Postrouting { + return nil + } + + replyTID := tid.reply() + replyTID.dstAddr = address + replyTID.dstPort = port + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil + } + conn = newConn(tid, replyTID, manipSource, hook) ct.insertConn(conn) return conn } @@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Now that we hold the locks, ensure the tuple hasn't been inserted by // another thread. + // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? alreadyInserted := false for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { if other.tupleID == conn.original.tupleID { @@ -343,95 +369,17 @@ func (ct *ConnTrack) insertConn(conn *conn) { } } -// handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. -func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { - // If this is a noop entry, don't do anything. 
- if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For prerouting redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. - switch dir { - case dirOriginal: - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - case dirReply: - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated - // on inbound packets, so we don't recalculate them. However, we should - // support cases when they are validated, e.g. when we can't offload - // receive checksumming. - - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - -// handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For output redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. For prerouting redirection, we only reach this point - // when replying, so packet sources are modified. 
- if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - } else { - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - if gso != nil && gso.NeedsCsum { - tcpHeader.SetChecksum(xsum) - } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) - } - - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - // handlePacket will manipulate the port and address of the packet if the // connection exists. Returns whether, after the packet traverses the tables, // it should create a new entry in the table. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if pkt.NatDone { return false } - if hook != Prerouting && hook != Output { + switch hook { + case Prerouting, Input, Output, Postrouting: + default: return false } @@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } conn, dir := ct.connFor(pkt) - // Connection or Rule not found for the packet. + // Connection not found for the packet. 
if conn == nil { - return true + // If this is the last hook in the data path for this packet (Input if + // incoming, Postrouting if outgoing), indicate that a connection should be + // inserted by the end of this hook. + return hook == Input || hook == Postrouting } + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return false } + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be + // validated if checksum offloading is off. It may require IP defrag if the + // packets are fragmented. + + switch hook { + case Prerouting, Output: + if conn.manip == manipDestination { + switch dir { + case dirOriginal: + tcpHeader.SetDestinationPort(conn.reply.srcPort) + netHeader.SetDestinationAddress(conn.reply.srcAddr) + case dirReply: + tcpHeader.SetSourcePort(conn.original.dstPort) + netHeader.SetSourceAddress(conn.original.dstAddr) + } + pkt.NatDone = true + } + case Input, Postrouting: + if conn.manip == manipSource { + switch dir { + case dirOriginal: + tcpHeader.SetSourcePort(conn.reply.dstPort) + netHeader.SetSourceAddress(conn.reply.dstAddr) + case dirReply: + tcpHeader.SetDestinationPort(conn.original.srcPort) + netHeader.SetDestinationAddress(conn.original.srcAddr) + } + pkt.NatDone = true + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + if !pkt.NatDone { + return false + } + switch hook { - case Prerouting: - handlePacketPrerouting(pkt, conn, dir) - case Output: - handlePacketOutput(pkt, conn, gso, r, dir) + case Prerouting, Input: + case Output, Postrouting: + // Calculate the TCP checksum and set it. 
+ tcpHeader.SetChecksum(0) + length := uint16(len(tcpHeader) + pkt.Data().Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) } - pkt.NatDone = true // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle @@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ if conn == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip == manipNone { - // Unmanipulated connection. + } else if conn.manip != manipDestination { + // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 16ee75bc4..7d3725681 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: vv.ToView().ToVectorisedView(), }) - // TODO(b/143425874) Decrease the TTL field in forwarded packets. + // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets. 
_ = r.WriteHeaderIncludedPacket(pkt) } @@ -117,7 +117,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) @@ -125,11 +125,11 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH b[srcAddrOffset] = r.LocalAddress()[0] b[protocolNumberOffset] = byte(params.Protocol) - return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt) + return f.nic.WritePacket(r, fwdTestNetNumber, pkt) } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -139,7 +139,7 @@ func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *Packet return &tcpip.ErrMalformedHeader{} } - return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) + return f.nic.WritePacket(r, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -264,6 +264,8 @@ type fwdTestPacketInfo struct { Pkt *PacketBuffer } +var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil) + type fwdTestLinkEndpoint struct { dispatcher NetworkDispatcher mtu uint32 @@ -306,11 +308,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { return caps | CapabilityResolutionRequired } -// GSOMaxSize returns the maximum GSO packet size. 
-func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { @@ -322,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -338,10 +335,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, pkt) + e.WritePacket(r, protocol, pkt) n++ } diff --git a/pkg/tcpip/stack/hook_string.go b/pkg/tcpip/stack/hook_string.go new file mode 100644 index 000000000..3dc8a7b02 --- /dev/null +++ b/pkg/tcpip/stack/hook_string.go @@ -0,0 +1,41 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type Hook ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Prerouting-0] + _ = x[Input-1] + _ = x[Forward-2] + _ = x[Output-3] + _ = x[Postrouting-4] + _ = x[NumHooks-5] +} + +const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks" + +var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47} + +func (i Hook) String() string { + if i >= Hook(len(_Hook_index)-1) { + return "Hook(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Hook_name[_Hook_index[i]:_Hook_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 52890f6eb..e2894c548 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -175,9 +175,10 @@ func DefaultTables() *IPTables { }, }, priorities: [NumHooks][]TableID{ - Prerouting: {MangleID, NATID}, - Input: {NATID, FilterID}, - Output: {MangleID, NATID, FilterID}, + Prerouting: {MangleID, NATID}, + Input: {NATID, FilterID}, + Output: {MangleID, NATID, FilterID}, + Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -266,12 +267,12 @@ const ( // should continue traversing the network stack and false when it should be // dropped. 
// -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from // which address can be gathered. Currently, address is only needed for // prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -285,7 +286,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer // Packets are manipulated only if connection and matching // NAT rule exists. - shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, r) // Go through each table containing the hook. priorities := it.priorities[hook] @@ -302,7 +303,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -313,7 +314,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer // Any Return from a built-in chain means we have to // call the underflow. 
underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -385,10 +386,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { + if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +409,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. 
for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +430,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,7 +456,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -478,7 +479,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) + return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) } // OriginalDst returns the original destination of redirected connections. 
It diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 0e8b90c9b..2812c89aa 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,7 +79,7 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -103,7 +103,7 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. 
The current // implementation only works for Prerouting and calls pkt.Clone(), neither // of which should be the case. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -174,7 +174,85 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // packet of the connection comes here. Other packets will be // manipulated in connection tracking. if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, gso, r) + ct.handlePacket(pkt, hook, r) + } + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} + +// SNATTarget modifies the source port/IP in the outgoing packets. +type SNATTarget struct { + Addr tcpip.Address + Port uint16 + + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. + if st.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + st.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + // Packet is already manipulated. + if pkt.NatDone { + return RuleAccept, 0 + } + + // Drop the packet if network and transport header are not set. 
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { + return RuleDrop, 0 + } + + switch hook { + case Postrouting, Input: + case Prerouting, Output, Forward: + panic(fmt.Sprintf("%s not supported", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + udpHeader := header.UDP(pkt.TransportHeader().View()) + udpHeader.SetChecksum(0) + udpHeader.SetSourcePort(st.Port) + netHeader := pkt.Network() + netHeader.SetSourceAddress(st.Addr) + + // Only calculate the checksum if offloading isn't supported. + if r.RequiresTXTransportChecksum() { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) + } + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } + pkt.NatDone = true + case header.TCPProtocolNumber: + if ct == nil { + return RuleAccept, 0 + } + + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. 
+ if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { + ct.handlePacket(pkt, hook, r) } default: return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index b0d84befb..4631ab93f 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -345,5 +345,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 14124ae66..c585b81b2 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -33,15 +33,19 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) +var ( + addr1 = testutil.MustParse6("a00::1") + addr2 = testutil.MustParse6("a00::2") + addr3 = testutil.MustParse6("a00::3") +) + const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") @@ -1142,57 +1146,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on }) } -// 
TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) +func TestDynamicConfigurationsDisabled(t *testing.T) { + const ( + nicID = 1 + maxRtrSolicitDelay = time.Second + ) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + prefix := tcpip.AddressWithPrefix{ + Address: testutil.MustParse6("102:304:506:708::"), + PrefixLen: 64, + } - // Rx an RA with non-zero lifetime. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - default: + tests := []struct { + name string + config func(bool) ipv6.NDPConfigurations + ra *stack.PacketBuffer + }{ + { + name: "No Router Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} + }, + ra: raBuf(llAddr2, 1000), + }, + { + name: "No Prefix Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0), + }, + { + name: "No Autogenerate Addresses", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Being configured to discover routers/prefixes or auto-generate + // addresses means RAs must be handled, and router/prefix discovery or + // SLAAC must be enabled. + // + // This tests all possible combinations of the configurations where + // router/prefix discovery or SLAAC are disabled. 
+ for i := 0; i < 7; i++ { + handle := ipv6.HandlingRAsDisabled + if i&1 != 0 { + handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled + } + enable := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + ndpConfigs := test.config(enable) + ndpConfigs.HandleRAs = handle + ndpConfigs.MaxRtrSolicitations = 1 + ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay + ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, + }) + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding + ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) + } + stats := ep.Stats() + v6Stats, ok := stats.(*ipv6.Stats) + if !ok { + t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) + } + + // Make sure that when handling RAs are enabled, we solicit routers. 
+ clock.Advance(maxRtrSolicitDelay) + if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { + t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) + } + if handleRAsDisabled { + if p, ok := e.Read(); ok { + t.Errorf("unexpectedly got a packet = %#v", p) + } + } else if p, ok := e.Read(); !ok { + t.Error("expected router solicitation packet") + } else if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } else { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(nil)), + ) + } + + // Make sure we do not discover any routers or prefixes, or perform + // SLAAC on reception of an RA. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) + // Make sure that the unhandled RA stat is only incremented when + // handling RAs is disabled. 
+ if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { + t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e) + default: + } + }) } }) } } +func boolToUint64(v bool) uint64 { + if v { + return 1 + } + return 0 +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // discovered flag set to discovered. func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) } +func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { + tests := [...]struct { + name string + handleRAs ipv6.HandleRAsConfiguration + forwarding bool + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding disabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding enabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f(t, test.handleRAs, test.forwarding) + }) + } +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. 
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { @@ -1203,7 +1348,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1237,103 +1382,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") } - 
default: - t.Fatal("expected router discovery event") } - } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") } - case <-time.After(timeout): - t.Fatal("timed out waiting for router discovery event") } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Rx an RA from another router (lladdr3) with non-zero lifetime. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Rx an RA from lladdr2 with lesser lifetime. 
- const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from another router (lladdr3) with non-zero lifetime. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) + expectRouterEvent(llAddr3, true) - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + // Rx an RA from lladdr2 with lesser lifetime. 
+ const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. 
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + }) } // TestRouterDiscoveryMaxRouters tests that only @@ -1347,7 +1498,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1386,57 +1537,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } } -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. 
- for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for prefix on nic with ID 1, and the // discovered flag set to discovered. 
func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { @@ -1455,8 +1555,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1494,87 +1593,93 @@ func TestPrefixDiscovery(t *testing.T) { prefix2, subnet2, _ := prefixSubnetAddr(1, "") prefix3, subnet3, _ := prefixSubnetAddr(2, "") - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch 
(-want +got):\n%s", diff) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") } - default: - t.Fatal("expected prefix discovery event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) + // Receive an RA with prefix2 in a PI. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for prefix discovery event") - } - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) + // Wait for prefix2's most recent invalidation job plus some buffer to + // expire. 
+ select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) + }) } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1590,7 +1695,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }() prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, } subnet := prefix.Subnet() @@ -1603,7 +1708,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1688,7 +1793,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: false, DiscoverOnLinkPrefixes: true, }, @@ -1753,53 +1858,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) return containsAddr(list, protocolAddress) } -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. 
-func TestNoAutoGenAddr(t *testing.T) { - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for addr on nic with ID 1, and the // event type is set to eventType. func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { @@ -1808,7 +1866,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. 
-func TestAutoGenAddr2(t *testing.T) { +func TestAutoGenAddr(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second saved := ipv6.MinPrefixInformationValidLifetimeForUpdate @@ -1820,96 +1878,102 @@ func TestAutoGenAddr2(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if 
diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } - default: - t.Fatal("expected addr auto gen event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } - // Refresh valid lifetime for addr of prefix1. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + }) } func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { @@ -1997,7 +2061,7 @@ func 
TestAutoGenTempAddr(t *testing.T) { RetransmitTimer: test.retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2298,7 +2362,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2385,7 +2449,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2534,7 +2598,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2735,7 +2799,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, @@ -2880,7 +2944,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, 
AutoGenGlobalAddresses: true, }, NDPDisp: ndpDisp, @@ -3347,7 +3411,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3490,7 +3554,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3557,7 +3621,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3723,7 +3787,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3805,7 +3869,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3969,7 +4033,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Global address", ndpConfigs: 
ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { @@ -3996,7 +4060,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Temporary address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -4146,7 +4210,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4274,7 +4338,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4480,7 +4544,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4531,7 +4595,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4625,8 +4689,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } } -// TestCleanupNDPState tests that all discovered routers and prefixes, and -// auto-generated 
addresses are invalidated when a NIC becomes a router. +func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { + const ( + lifetimeSeconds = 999 + nicID = 1 + ) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenLinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) + + e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1) + if err := s.CreateNIC(nicID, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen} + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID) + } + + prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1) + e1.InjectInbound( + header.IPv6ProtocolNumber, + raBufWithPI( + llAddr3, + lifetimeSeconds, + prefix, + true, /* onLink */ + true, /* auto */ + lifetimeSeconds, + lifetimeSeconds, + ), + ) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + } + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", 
diff) + } + default: + t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID) + } + + // Enabling or disabling forwarding should not invalidate discovered prefixes + // or routers, or auto-generated address. + for _, forwarding := range [...]bool{true, false} { + t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpected router event = %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpected prefix event = %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpected auto-gen addr event = %#v", e) + default: + } + }) + } +} + func TestCleanupNDPState(t *testing.T) { const ( lifetimeSeconds = 5 @@ -4655,18 +4821,6 @@ func TestCleanupNDPState(t *testing.T) { maxAutoGenAddrEvents int skipFinalAddrCheck bool }{ - // A NIC should still keep its auto-generated link-local address when - // becoming a router. - { - name: "Enable forwarding", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) - }, - keepAutoGenLinkLocal: true, - maxAutoGenAddrEvents: 4, - }, - // A NIC should cleanup all NDP state when it is disabled. 
{ name: "Disable NIC", @@ -4718,7 +4872,7 @@ func TestCleanupNDPState(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, DiscoverOnLinkPrefixes: true, AutoGenGlobalAddresses: true, @@ -4991,7 +5145,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -5182,96 +5336,127 @@ func TestRouterSolicitation(t *testing.T) { }, } + subTests := []struct { + name string + handleRAs ipv6.HandleRAsConfiguration + afterFirstRS func(*testing.T, *stack.Stack) + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + afterFirstRS: func(*testing.T, *stack.Stack) {}, + }, + + // Enabling forwarding when RAs are always configured to be handled + // should not stop router solicitations. 
+ { + name: "Handle RAs always", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + afterFirstRS: func(t *testing.T, s *stack.Stack) { + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() + + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") + } - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - // Make sure the right remote link address is used. 
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: subTest.handleRAs, + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - s := stack.New(stack.Options{ - NetworkProtocols: 
[]stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) - remaining-- - } + subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(test.effectiveRtrSolicitInt) - } - } + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) + } else { + waitForPkt(test.effectiveRtrSolicitInt) + } + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } + // Make sure no more RS. 
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) } }) } @@ -5362,13 +5547,14 @@ func TestStopStartSolicitingRouters(t *testing.T) { } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS()) } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, MaxRtrSolicitationDelay: delay, diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 48bb75e2f..9821a18d3 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -1556,7 +1556,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() - clock := &tcpip.StdClock{} + clock := tcpip.NewStdClock() linkRes := newTestNeighborResolver(nil, config, clock) linkRes.delay = 0 diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index bb2b2d705..1d39ee73d 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -26,14 +26,13 @@ import ( 
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 entryTestNICID tcpip.NICID = 1 - entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") @@ -44,6 +43,11 @@ const ( entryTestNetDefaultMTU = 65536 ) +var ( + entryTestAddr1 = testutil.MustParse6("a::1") + entryTestAddr2 = testutil.MustParse6("a::2") +) + // runImmediatelyScheduledJobs runs all jobs scheduled to run at the current // time. func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ca15c0691..8d615500f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -316,30 +316,30 @@ func (n *nic) IsLoopback() bool { } // WritePacket implements NetworkLinkEndpoint. 
-func (n *nic) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { - _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) +func (n *nic) WritePacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { + _, err := n.enqueuePacketBuffer(r, protocol, pkt) return err } -func (n *nic) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) writePacketBuffer(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { switch pkt := pkt.(type) { case *PacketBuffer: - if err := n.writePacket(r, gso, protocol, pkt); err != nil { + if err := n.writePacket(r, protocol, pkt); err != nil { return 0, err } return 1, nil case *PacketBufferList: - return n.writePackets(r, gso, protocol, *pkt) + return n.writePackets(r, protocol, *pkt) default: panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) } } -func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) enqueuePacketBuffer(r *Route, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { routeInfo, _, err := r.resolvedFields(nil) switch err.(type) { case nil: - return n.writePacketBuffer(routeInfo, gso, protocol, pkt) + return n.writePacketBuffer(routeInfo, protocol, pkt) case *tcpip.ErrWouldBlock: // As per relevant RFCs, we should queue packets while we wait for link // resolution to complete. @@ -358,28 +358,27 @@ func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProt // SHOULD be limited to some small value. When a queue overflows, the new // arrival SHOULD replace the oldest entry. Once address resolution // completes, the node transmits any queued packets. 
- return n.linkResQueue.enqueue(r, gso, protocol, pkt) + return n.linkResQueue.enqueue(r, protocol, pkt) default: return 0, err } } // WritePacketToRemote implements NetworkInterface. -func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr - return n.writePacket(r, gso, protocol, pkt) + return n.writePacket(r, protocol, pkt) } -func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() pkt.EgressRoute = r - pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol - if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { + if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil { return err } @@ -389,18 +388,17 @@ func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. 
-func (n *nic) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return n.enqueuePacketBuffer(r, gso, protocol, &pkts) +func (n *nic) WritePackets(r *Route, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return n.enqueuePacketBuffer(r, protocol, &pkts) } -func (n *nic) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { +func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r - pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol } - writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) + writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index c0f956e53..8a3005295 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -65,12 +65,12 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { } // WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error { +func (*testIPv6Endpoint) WritePacket(*Route, NetworkHeaderParams, *PacketBuffer) tcpip.Error { return nil } // WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { +func (*testIPv6Endpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { // Our tests don't use this so we don't support it. 
return 0, &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 8f288675d..9527416cf 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -103,7 +103,7 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. EgressRoute RouteInfo - GSOOptions *GSO + GSOOptions GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. @@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - return NewPacketBuffer(PacketBufferOptions{ + newPk := NewPacketBuffer(PacketBufferOptions{ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), }) + // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to + // maintain this flag in the packet. Currently conntrack needs this flag to + // tell if a noop connection should be inserted at Input hook. Once conntrack + // redefines the manipulation field as mutable, we won't need the special noop + // connection. + if pk.NatDone { + newPk.NatDone = true + } + return newPk } // headerInfo stores metadata about a header in a packet. @@ -355,9 +364,10 @@ func (d PacketData) PullUp(size int) (buffer.View, bool) { return d.pk.data.PullUp(size) } -// TrimFront removes count from the beginning of d. It panics if count > -// d.Size(). -func (d PacketData) TrimFront(count int) { +// DeleteFront removes count from the beginning of d. It panics if count > +// d.Size(). All backing storage references after the front of the d are +// invalidated. 
+func (d PacketData) DeleteFront(count int) { d.pk.data.TrimFront(count) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 6728370c3..bd4eb4fed 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkData(t, pk, test.data) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), - concatViews(test.link, test.network, test.transport, test.data)) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(test.link, test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(test.transport, test.data)) + // Check the after state. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.link, + network: test.network, + transport: test.transport, + data: test.data, + }) }) } } @@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) { if got, want := pk.Size(), len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - // After state of pk. 
- var ( - link = test.data[:test.link] - network = test.data[test.link:][:test.network] - transport = test.data[test.link+test.network:][:test.transport] - payload = test.data[allHdrSize:] - ) - checkData(t, pk, payload) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(link, network, transport, payload)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(network, transport, payload)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(transport, payload)) + // Check the after state of pk. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.data[:test.link], + network: test.data[test.link:][:test.network], + transport: test.data[test.link+test.network:][:test.transport], + data: test.data[allHdrSize:], + }) }) } } @@ -252,6 +226,39 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) { }) } +// This is a very obscure use-case seen in the code that verifies packets +// before sending them out. It tries to parse the headers to verify. +// PacketHeader was initially not designed to mix Push() and Consume(), but it +// works and it's been relied upon. Include a test here. +func TestPacketHeaderPushConsumeMixed(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := append([]byte(nil), network...) + initData = append(initData, data...) 
+ pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Consume network header + gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) + if !ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) + } + checkViewEqual(t, "gotNetwork", gotNetwork, network) + + // 2. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + network: network, + data: data, + }) +} + func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { const headerSize = 10 @@ -397,11 +404,11 @@ func TestPacketBufferData(t *testing.T) { } }) - // TrimFront + // DeleteFront for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().TrimFront(n) + pkt.Data().DeleteFront(n) checkData(t, pkt, []byte(tc.data)[n:]) }) @@ -494,6 +501,37 @@ func TestPacketBufferData(t *testing.T) { } } +type packetContents struct { + link buffer.View + network buffer.View + transport buffer.View + data buffer.View +} + +func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { + t.Helper() + // Headers. + checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) + checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) + checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) + // Data. + checkData(t, pk, want.data) + // Whole packet. + checkViewEqual(t, prefix+"pk.Views()", + concatViews(pk.Views()...), + concatViews(want.link, want.network, want.transport, want.data)) + // PayloadSince. 
+ checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(want.link, want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(want.transport, want.data)) +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -510,19 +548,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkData(t, pk, data) - checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) - // Check the initial values for each header. - checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) - checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) - checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) - // Check the initial valies for PayloadSince. 
- checkViewEqual(t, "Initial PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), data) + checkPacketContents(t, "Initial ", pk, packetContents{ + data: data, + }) } func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index e936aa728..13e8907ec 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -46,7 +46,6 @@ func (p *PacketBufferList) len() int { type pendingPacket struct { routeInfo RouteInfo - gso *GSO proto tcpip.NetworkProtocolNumber pkt pendingPacketBuffer } @@ -119,7 +118,7 @@ func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpi // If the maximum number of pending resolutions is reached, the packets // associated with the oldest link resolution will be dequeued as if they failed // link resolution. -func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { f.mu.Lock() // Make sure we attempt resolution while holding f's lock so that we avoid // a race where link resolution completes before we enqueue the packets. @@ -137,7 +136,7 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N // The route resolved immediately, so we don't need to wait for link // resolution to send the packet. f.mu.Unlock() - return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt) + return f.nic.writePacketBuffer(routeInfo, proto, pkt) case *tcpip.ErrWouldBlock: // We need to wait for link resolution to complete. 
default: @@ -150,7 +149,6 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N packets, ok := f.mu.packets[ch] packets = append(packets, pendingPacket{ routeInfo: routeInfo, - gso: gso, proto: proto, pkt: pkt, }) @@ -211,7 +209,7 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l for _, p := range packets { if err == nil { p.routeInfo.RemoteLinkAddress = linkAddr - _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + _, _ = f.nic.writePacketBuffer(p.routeInfo, p.proto, p.pkt) } else { f.incrementOutgoingPacketErrors(p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ff3a385e1..e26225552 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -537,14 +537,14 @@ type NetworkInterface interface { CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool // WritePacketToRemote writes the packet to the given remote link address. - WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePacket writes a packet with the given protocol through the given // route. // // WritePacket takes ownership of the packet buffer. The packet buffer's // network and transport header must be set. - WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacket(*Route, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. @@ -554,7 +554,7 @@ type NetworkInterface interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. 
- WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + WritePackets(*Route, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) // HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP // request or NDP Neighbor Solicitation). @@ -610,12 +610,12 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error + WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and // underlying packets. - WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) + WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. @@ -756,11 +756,6 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityHardwareGSO - - // CapabilitySoftwareGSO indicates the link endpoint supports of sending - // multiple packets using a single call (LinkEndpoint.WritePackets). - CapabilitySoftwareGSO ) // NetworkLinkEndpoint is a data-link layer that supports sending network @@ -832,7 +827,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. 
- WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacket(RouteInfo, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. @@ -842,7 +837,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -1047,10 +1042,29 @@ type GSO struct { MaxSize uint32 } +// SupportedGSO returns the type of segmentation offloading supported. +type SupportedGSO int + +const ( + // GSONotSupported indicates that segmentation offloading is not supported. + GSONotSupported SupportedGSO = iota + + // HWGSOSupported indicates that segmentation offloading may be performed by + // the hardware. + HWGSOSupported + + // SWGSOSupported indicates that segmentation offloading may be performed in + // software. + SWGSOSupported +) + // GSOEndpoint provides access to GSO properties. type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 + + // SupportedGSO returns the supported segmentation offloading. + SupportedGSO() SupportedGSO } // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. 
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 39344808d..8a044c073 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp localAddr = addressEndpoint.AddressWithPrefix().Address } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) { addressEndpoint.DecRef() return nil } @@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool { // HasSoftwareGSOCapability returns true if the route supports software GSO. func (r *Route) HasSoftwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == SWGSOSupported + } + return false } // HasHardwareGSOCapability returns true if the route supports hardware GSO. func (r *Route) HasHardwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == HWGSOSupported + } + return false } // HasSaveRestoreCapability returns true if the route supports save/restore. @@ -448,22 +454,22 @@ func (r *Route) isValidForOutgoingRLocked() bool { } // WritePacket writes the packet through the given route. 
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { +func (r *Route) WritePacket(params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { return &tcpip.ErrInvalidEndpointState{} } - return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. -func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (r *Route) WritePackets(pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { if !r.isValidForOutgoing() { return 0, &tcpip.ErrInvalidEndpointState{} } - return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 931a97ddc..436392f23 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -35,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" ) @@ -56,306 +55,6 @@ type transportProtocolState struct { defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } -// TCPProbeFunc is the expected function type for a TCP probe function to be -// passed to stack.AddTCPProbe. -type TCPProbeFunc func(s TCPEndpointState) - -// TCPCubicState is used to hold a copy of the internal cubic state when the -// TCPProbeFunc is invoked. 
-type TCPCubicState struct { - WLastMax float64 - WMax float64 - T time.Time - TimeSinceLastCongestion time.Duration - C float64 - K float64 - Beta float64 - WC float64 - WEst float64 -} - -// TCPRACKState is used to hold a copy of the internal RACK state when the -// TCPProbeFunc is invoked. -type TCPRACKState struct { - XmitTime time.Time - EndSequence seqnum.Value - FACK seqnum.Value - RTT time.Duration - Reord bool - DSACKSeen bool - ReoWnd time.Duration - ReoWndIncr uint8 - ReoWndPersist int8 - RTTSeq seqnum.Value -} - -// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. -type TCPEndpointID struct { - // LocalPort is the local port associated with the endpoint. - LocalPort uint16 - - // LocalAddress is the local [network layer] address associated with - // the endpoint. - LocalAddress tcpip.Address - - // RemotePort is the remote port associated with the endpoint. - RemotePort uint16 - - // RemoteAddress it the remote [network layer] address associated with - // the endpoint. - RemoteAddress tcpip.Address -} - -// TCPFastRecoveryState holds a copy of the internal fast recovery state of a -// TCP endpoint. -type TCPFastRecoveryState struct { - // Active if true indicates the endpoint is in fast recovery. - Active bool - - // First is the first unacknowledged sequence number being recovered. - First seqnum.Value - - // Last is the 'recover' sequence number that indicates the point at - // which we should exit recovery barring any timeouts etc. - Last seqnum.Value - - // MaxCwnd is the maximum value we are permitted to grow the congestion - // window during recovery. This is set at the time we enter recovery. - MaxCwnd int - - // HighRxt is the highest sequence number which has been retransmitted - // during the current loss recovery phase. - // See: RFC 6675 Section 2 for details. 
- HighRxt seqnum.Value - - // RescueRxt is the highest sequence number which has been - // optimistically retransmitted to prevent stalling of the ACK clock - // when there is loss at the end of the window and no new data is - // available for transmission. - // See: RFC 6675 Section 2 for details. - RescueRxt seqnum.Value -} - -// TCPReceiverState holds a copy of the internal state of the receiver for -// a given TCP endpoint. -type TCPReceiverState struct { - // RcvNxt is the TCP variable RCV.NXT. - RcvNxt seqnum.Value - - // RcvAcc is the TCP variable RCV.ACC. - RcvAcc seqnum.Value - - // RcvWndScale is the window scaling to use for inbound segments. - RcvWndScale uint8 - - // PendingBufUsed is the number of bytes pending in the receive - // queue. - PendingBufUsed int -} - -// TCPSenderState holds a copy of the internal state of the sender for -// a given TCP Endpoint. -type TCPSenderState struct { - // LastSendTime is the time at which we sent the last segment. - LastSendTime time.Time - - // DupAckCount is the number of Duplicate ACK's received. - DupAckCount int - - // SndCwnd is the size of the sending congestion window in packets. - SndCwnd int - - // Ssthresh is the slow start threshold in packets. - Ssthresh int - - // SndCAAckCount is the number of packets consumed in congestion - // avoidance mode. - SndCAAckCount int - - // Outstanding is the number of packets in flight. - Outstanding int - - // SackedOut is the number of packets which have been selectively acked. - SackedOut int - - // SndWnd is the send window size in bytes. - SndWnd seqnum.Size - - // SndUna is the next unacknowledged sequence number. - SndUna seqnum.Value - - // SndNxt is the sequence number of the next segment to be sent. - SndNxt seqnum.Value - - // RTTMeasureSeqNum is the sequence number being used for the latest RTT - // measurement. - RTTMeasureSeqNum seqnum.Value - - // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. 
- RTTMeasureTime time.Time - - // Closed indicates that the caller has closed the endpoint for sending. - Closed bool - - // SRTT is the smoothed round-trip time as defined in section 2 of - // RFC 6298. - SRTT time.Duration - - // RTO is the retransmit timeout as defined in section of 2 of RFC 6298. - RTO time.Duration - - // RTTVar is the round-trip time variation as defined in section 2 of - // RFC 6298. - RTTVar time.Duration - - // SRTTInited if true indicates take a valid RTT measurement has been - // completed. - SRTTInited bool - - // MaxPayloadSize is the maximum size of the payload of a given segment. - // It is initialized on demand. - MaxPayloadSize int - - // SndWndScale is the number of bits to shift left when reading the send - // window size from a segment. - SndWndScale uint8 - - // MaxSentAck is the highest acknowledgement number sent till now. - MaxSentAck seqnum.Value - - // FastRecovery holds the fast recovery state for the endpoint. - FastRecovery TCPFastRecoveryState - - // Cubic holds the state related to CUBIC congestion control. - Cubic TCPCubicState - - // RACKState holds the state related to RACK loss detection algorithm. - RACKState TCPRACKState -} - -// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. -type TCPSACKInfo struct { - // Blocks is the list of SACK Blocks that identify the out of order segments - // held by a given TCP endpoint. - Blocks []header.SACKBlock - - // ReceivedBlocks are the SACK blocks received by this endpoint - // from the peer endpoint. - ReceivedBlocks []header.SACKBlock - - // MaxSACKED is the highest sequence number that has been SACKED - // by the peer. - MaxSACKED seqnum.Value -} - -// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning. -type RcvBufAutoTuneParams struct { - // MeasureTime is the time at which the current measurement - // was started. 
- MeasureTime time.Time - - // CopiedBytes is the number of bytes copied to user space since - // this measure began. - CopiedBytes int - - // PrevCopiedBytes is the number of bytes copied to userspace in - // the previous RTT period. - PrevCopiedBytes int - - // RcvBufSize is the auto tuned receive buffer size. - RcvBufSize int - - // RTT is the smoothed RTT as measured by observing the time between - // when a byte is first acknowledged and the receipt of data that is at - // least one window beyond the sequence number that was acknowledged. - RTT time.Duration - - // RTTVar is the "round-trip time variation" as defined in section 2 - // of RFC6298. - RTTVar time.Duration - - // RTTMeasureSeqNumber is the highest acceptable sequence number at the - // time this RTT measurement period began. - RTTMeasureSeqNumber seqnum.Value - - // RTTMeasureTime is the absolute time at which the current RTT - // measurement period began. - RTTMeasureTime time.Time - - // Disabled is true if an explicit receive buffer is set for the - // endpoint. - Disabled bool -} - -// TCPEndpointState is a copy of the internal state of a TCP endpoint. -type TCPEndpointState struct { - // ID is a copy of the TransportEndpointID for the endpoint. - ID TCPEndpointID - - // SegTime denotes the absolute time when this segment was received. - SegTime time.Time - - // RcvBufSize is the size of the receive socket buffer for the endpoint. - RcvBufSize int - - // RcvBufUsed is the amount of bytes actually held in the receive socket - // buffer for the endpoint. - RcvBufUsed int - - // RcvBufAutoTuneParams is used to hold state variables to compute - // the auto tuned receive buffer size. - RcvAutoParams RcvBufAutoTuneParams - - // RcvClosed if true, indicates the endpoint has been closed for reading. - RcvClosed bool - - // SendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1. 
- SendTSOk bool - - // RecentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - RecentTS uint32 - - // TSOffset is a randomized offset added to the value of the TSVal field - // in the timestamp option. - TSOffset uint32 - - // SACKPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - SACKPermitted bool - - // SACK holds TCP SACK related information for this endpoint. - SACK TCPSACKInfo - - // SndBufSize is the size of the socket send buffer. - SndBufSize int - - // SndBufUsed is the number of bytes held in the socket send buffer. - SndBufUsed int - - // SndClosed indicates that the endpoint has been closed for sends. - SndClosed bool - - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - - // PacketTooBigCount is used to notify the main protocol routine how - // many times a "packet too big" control packet is received. - PacketTooBigCount int - - // SndMTU is the smallest MTU seen in the control packets received. - SndMTU int - - // Receiver holds variables related to the TCP receiver for the endpoint. - Receiver TCPReceiverState - - // Sender holds state related to the TCP Sender for the endpoint. - Sender TCPSenderState -} - // ResumableEndpoint is an endpoint that needs to be resumed after restore. type ResumableEndpoint interface { // Resume resumes an endpoint after restore. This can be used to restart @@ -455,7 +154,7 @@ type Stack struct { // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. 
- receiveBufferSize ReceiveBufferSizeOption + receiveBufferSize tcpip.ReceiveBufferSizeOption // tcpInvalidRateLimit is the maximal rate for sending duplicate // acknowledgements in response to incoming TCP packets that are for an existing @@ -623,7 +322,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {} func New(opts Options) *Stack { clock := opts.Clock if clock == nil { - clock = &tcpip.StdClock{} + clock = tcpip.NewStdClock() } if opts.UniqueID == nil { @@ -669,7 +368,7 @@ func New(opts Options) *Stack { Default: DefaultBufferSize, Max: DefaultMaxBufferSize, }, - receiveBufferSize: ReceiveBufferSizeOption{ + receiveBufferSize: tcpip.ReceiveBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, @@ -1344,7 +1043,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) + isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) @@ -1381,7 +1080,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. 
var chosenRoute tcpip.Route @@ -1874,7 +1573,7 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, ReserveHeaderBytes: int(nic.MaxHeaderLength()), Data: payload, }) - return nic.WritePacketToRemote(remote, nil, netProto, pkt) + return nic.WritePacketToRemote(remote, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index dfec4258a..33824afd0 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,6 +14,78 @@ package stack +import "time" + // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack + +// saveT is invoked by stateify. +func (t *TCPCubicState) saveT() unixTime { + return unixTime{t.T.Unix(), t.T.UnixNano()} +} + +// loadT is invoked by stateify. +func (t *TCPCubicState) loadT(unix unixTime) { + t.T = time.Unix(unix.second, unix.nano) +} + +// saveXmitTime is invoked by stateify. +func (t *TCPRACKState) saveXmitTime() unixTime { + return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (t *TCPRACKState) loadXmitTime(unix unixTime) { + t.XmitTime = time.Unix(unix.second, unix.nano) +} + +// saveLastSendTime is invoked by stateify. +func (t *TCPSenderState) saveLastSendTime() unixTime { + return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} +} + +// loadLastSendTime is invoked by stateify. +func (t *TCPSenderState) loadLastSendTime(unix unixTime) { + t.LastSendTime = time.Unix(unix.second, unix.nano) +} + +// saveRTTMeasureTime is invoked by stateify. +func (t *TCPSenderState) saveRTTMeasureTime() unixTime { + return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} +} + +// loadRTTMeasureTime is invoked by stateify. 
+func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { + t.RTTMeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { + return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} +} + +// loadMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { + r.MeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveRTTMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { + return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} +} + +// loadRTTMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { + r.RTTMeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveSegTime is invoked by stateify. +func (t *TCPEndpointState) saveSegTime() unixTime { + return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} +} + +// loadSegTime is invoked by stateify. +func (t *TCPEndpointState) loadSegTime(unix unixTime) { + t.SegTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 3066f4ffd..80e8e0089 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -68,7 +68,7 @@ func (s *Stack) SetOption(option interface{}) tcpip.Error { s.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. 
if v.Min < MinBufferSize { @@ -107,7 +107,7 @@ func (s *Stack) Option(option interface{}) tcpip.Error { s.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.ReceiveBufferSizeOption: s.mu.RLock() *v = s.receiveBufferSize s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 2814b94b4..d2c40cc43 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -39,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) @@ -137,11 +138,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data().TrimFront(fakeNetHeaderLen) + // DeleteFront invalidates slices. Make a copy before trimming. + nb := append([]byte(nil), hdr...) + pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -170,7 +173,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { // Increment the sent packet count in the protocol descriptor. 
f.proto.sendPacketCount[int(r.RemoteAddress()[0])%len(f.proto.sendPacketCount)]++ @@ -189,11 +192,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params return nil } - return f.nic.WritePacket(r, gso, fakeNetNumber, pkt) + return f.nic.WritePacket(r, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -436,7 +439,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error } func send(r *stack.Route, payload buffer.View) tcpip.Error { - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + return r.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), })) @@ -1461,7 +1464,7 @@ func TestExternalSendWithHandleLocal(t *testing.T) { if n := ep.Drain(); n != 0 { t.Fatalf("got ep.Drain() = %d, want = 0", n) } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS, @@ -1645,10 +1648,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. 
nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} - nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} - nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") + nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. s := stack.New(stack.Options{ @@ -2789,25 +2792,27 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") - ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") - toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - ipv6ToIPv4Addr2 = 
tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 lifetimeSeconds = 9999 ) + var ( + linkLocalAddr1 = testutil.MustParse6("fe80::1") + linkLocalAddr2 = testutil.MustParse6("fe80::2") + linkLocalMulticastAddr = testutil.MustParse6("ff02::1") + uniqueLocalAddr1 = testutil.MustParse6("fc00::1") + uniqueLocalAddr2 = testutil.MustParse6("fd00::2") + globalAddr1 = testutil.MustParse6("a000::1") + globalAddr2 = testutil.MustParse6("a000::2") + globalAddr3 = testutil.MustParse6("a000::3") + ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1") + ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2") + toredoAddr1 = testutil.MustParse6("2001::1") + toredoAddr2 = testutil.MustParse6("2001::2") + ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1") + ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2") + ) + prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) @@ -3017,7 +3022,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -3354,21 +3359,21 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - rs stack.ReceiveBufferSizeOption + rs tcpip.ReceiveBufferSizeOption err tcpip.Error }{ // Invalid configurations. 
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations - {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", 
tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -3377,7 +3382,7 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { if err := s.SetOption(tc.rs); err != tc.err { t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if tc.err == nil { if err := s.Option(&rs); err != nil { t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) @@ -3448,7 +3453,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } ipv4Subnet := ipv4Addr.Subnet() ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4Gateway := testutil.MustParse4("192.168.1.1") ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ Address: "\xc0\xa8\x01\x3a", PrefixLen: 31, @@ -4352,13 +4357,15 @@ func TestWritePacketToRemote(t *testing.T) { func TestClearNeighborCacheOnNICDisable(t *testing.T) { const ( - nicID = 1 - - ipv4Addr = tcpip.Address("\x01\x02\x03\x04") - ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04") + nicID = 1 linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + var ( + ipv4Addr = testutil.MustParse4("1.2.3.4") + ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304") + ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go new file mode 100644 index 000000000..ddff6e2d6 --- /dev/null +++ b/pkg/tcpip/stack/tcp.go @@ -0,0 +1,451 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// TCPProbeFunc is the expected function type for a TCP probe function to be +// passed to stack.AddTCPProbe. +type TCPProbeFunc func(s TCPEndpointState) + +// TCPCubicState is used to hold a copy of the internal cubic state when the +// TCPProbeFunc is invoked. +// +// +stateify savable +type TCPCubicState struct { + // WLastMax is the previous wMax value. + WLastMax float64 + + // WMax is the value of the congestion window at the time of the last + // congestion event. + WMax float64 + + // T is the time when the current congestion avoidance was entered. + T time.Time `state:".(unixTime)"` + + // TimeSinceLastCongestion denotes the time since the current + // congestion avoidance was entered. + TimeSinceLastCongestion time.Duration + + // C is the cubic constant as specified in RFC8312, page 11. + C float64 + + // K is the time period (in seconds) that the above function takes to + // increase the current window size to WMax if there are no further + // congestion events and is calculated using the following equation: + // + // K = cubic_root(WMax*(1-beta_cubic)/C) (Eq. 2, page 5) + K float64 + + // Beta is the CUBIC multiplication decrease factor. That is, when a + // congestion event is detected, CUBIC reduces its cwnd to + // WC(0)=WMax*beta_cubic. + Beta float64 + + // WC is window computed by CUBIC at time TimeSinceLastCongestion. 
It's + // calculated using the formula: + // + // WC(TimeSinceLastCongestion) = C*(t-K)^3 + WMax (Eq. 1) + WC float64 + + // WEst is the window computed by CUBIC at time + // TimeSinceLastCongestion+RTT i.e WC(TimeSinceLastCongestion+RTT). + WEst float64 +} + +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +// +// +stateify savable +type TCPRACKState struct { + // XmitTime is the transmission timestamp of the most recent + // acknowledged segment. + XmitTime time.Time `state:".(unixTime)"` + + // EndSequence is the ending TCP sequence number of the most recent + // acknowledged segment. + EndSequence seqnum.Value + + // FACK is the highest selectively or cumulatively acknowledged + // sequence. + FACK seqnum.Value + + // RTT is the round trip time of the most recently delivered packet on + // the connection (either cumulatively acknowledged or selectively + // acknowledged) that was not marked invalid as a possible spurious + // retransmission. + RTT time.Duration + + // Reord is true iff reordering has been detected on this connection. + Reord bool + + // DSACKSeen is true iff the connection has seen a DSACK. + DSACKSeen bool + + // ReoWnd is the reordering window time used for recording packet + // transmission times. It is used to defer the moment at which RACK + // marks a packet lost. + ReoWnd time.Duration + + // ReoWndIncr is the multiplier applied to adjust reorder window. + ReoWndIncr uint8 + + // ReoWndPersist is the number of loss recoveries before resetting + // reorder window. + ReoWndPersist int8 + + // RTTSeq is the SND.NXT when RTT is updated. + RTTSeq seqnum.Value +} + +// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. +// +// +stateify savable +type TCPEndpointID struct { + // LocalPort is the local port associated with the endpoint. + LocalPort uint16 + + // LocalAddress is the local [network layer] address associated with + // the endpoint. 
+ LocalAddress tcpip.Address + + // RemotePort is the remote port associated with the endpoint. + RemotePort uint16 + + // RemoteAddress it the remote [network layer] address associated with + // the endpoint. + RemoteAddress tcpip.Address +} + +// TCPFastRecoveryState holds a copy of the internal fast recovery state of a +// TCP endpoint. +// +// +stateify savable +type TCPFastRecoveryState struct { + // Active if true indicates the endpoint is in fast recovery. The + // following fields are only meaningful when Active is true. + Active bool + + // First is the first unacknowledged sequence number being recovered. + First seqnum.Value + + // Last is the 'recover' sequence number that indicates the point at + // which we should exit recovery barring any timeouts etc. + Last seqnum.Value + + // MaxCwnd is the maximum value we are permitted to grow the congestion + // window during recovery. This is set at the time we enter recovery. + // It exists to avoid attacks where the receiver intentionally sends + // duplicate acks to artificially inflate the sender's cwnd. + MaxCwnd int + + // HighRxt is the highest sequence number which has been retransmitted + // during the current loss recovery phase. See: RFC 6675 Section 2 for + // details. + HighRxt seqnum.Value + + // RescueRxt is the highest sequence number which has been + // optimistically retransmitted to prevent stalling of the ACK clock + // when there is loss at the end of the window and no new data is + // available for transmission. See: RFC 6675 Section 2 for details. + RescueRxt seqnum.Value +} + +// TCPReceiverState holds a copy of the internal state of the receiver for a +// given TCP endpoint. +// +// +stateify savable +type TCPReceiverState struct { + // RcvNxt is the TCP variable RCV.NXT. + RcvNxt seqnum.Value + + // RcvAcc is one beyond the last acceptable sequence number. That is, + // the "largest" sequence value that the receiver has announced to its + // peer that it's willing to accept. 
This may be different than RcvNxt + // + (last advertised receive window) if the receive window is reduced; + // in that case we have to reduce the window as we receive more data + // instead of shrinking it. + RcvAcc seqnum.Value + + // RcvWndScale is the window scaling to use for inbound segments. + RcvWndScale uint8 + + // PendingBufUsed is the number of bytes pending in the receive queue. + PendingBufUsed int +} + +// TCPRTTState holds a copy of information about the endpoint's round trip +// time. +// +// +stateify savable +type TCPRTTState struct { + // SRTT is the smoothed round trip time defined in section 2 of RFC + // 6298. + SRTT time.Duration + + // RTTVar is the round-trip time variation as defined in section 2 of + // RFC 6298. + RTTVar time.Duration + + // SRTTInited if true indicates that a valid RTT measurement has been + // completed. + SRTTInited bool +} + +// TCPSenderState holds a copy of the internal state of the sender for a given +// TCP Endpoint. +// +// +stateify savable +type TCPSenderState struct { + // LastSendTime is the timestamp at which we sent the last segment. + LastSendTime time.Time `state:".(unixTime)"` + + // DupAckCount is the number of Duplicate ACKs received. It is used for + // fast retransmit. + DupAckCount int + + // SndCwnd is the size of the sending congestion window in packets. + SndCwnd int + + // Ssthresh is the threshold between slow start and congestion + // avoidance. + Ssthresh int + + // SndCAAckCount is the number of packets acknowledged during + // congestion avoidance. When enough packets have been ack'd (typically + // cwnd packets), the congestion window is incremented by one. + SndCAAckCount int + + // Outstanding is the number of packets that have been sent but not yet + // acknowledged. + Outstanding int + + // SackedOut is the number of packets which have been selectively + // acked. + SackedOut int + + // SndWnd is the send window size in bytes. 
+ SndWnd seqnum.Size + + // SndUna is the next unacknowledged sequence number. + SndUna seqnum.Value + + // SndNxt is the sequence number of the next segment to be sent. + SndNxt seqnum.Value + + // RTTMeasureSeqNum is the sequence number being used for the latest + // RTT measurement. + RTTMeasureSeqNum seqnum.Value + + // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. + RTTMeasureTime time.Time `state:".(unixTime)"` + + // Closed indicates that the caller has closed the endpoint for + // sending. + Closed bool + + // RTO is the retransmit timeout as defined in section of 2 of RFC + // 6298. + RTO time.Duration + + // RTTState holds information about the endpoint's round trip time. + RTTState TCPRTTState + + // MaxPayloadSize is the maximum size of the payload of a given + // segment. It is initialized on demand. + MaxPayloadSize int + + // SndWndScale is the number of bits to shift left when reading the + // send window size from a segment. + SndWndScale uint8 + + // MaxSentAck is the highest acknowledgement number sent till now. + MaxSentAck seqnum.Value + + // FastRecovery holds the fast recovery state for the endpoint. + FastRecovery TCPFastRecoveryState + + // Cubic holds the state related to CUBIC congestion control. + Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState +} + +// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. +// +// +stateify savable +type TCPSACKInfo struct { + // Blocks is the list of SACK Blocks that identify the out of order + // segments held by a given TCP endpoint. + Blocks []header.SACKBlock + + // ReceivedBlocks are the SACK blocks received by this endpoint from + // the peer endpoint. + ReceivedBlocks []header.SACKBlock + + // MaxSACKED is the highest sequence number that has been SACKED by the + // peer. + MaxSACKED seqnum.Value +} + +// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning. 
+// +// +stateify savable +type RcvBufAutoTuneParams struct { + // MeasureTime is the time at which the current measurement was + // started. + MeasureTime time.Time `state:".(unixTime)"` + + // CopiedBytes is the number of bytes copied to user space since this + // measure began. + CopiedBytes int + + // PrevCopiedBytes is the number of bytes copied to userspace in the + // previous RTT period. + PrevCopiedBytes int + + // RcvBufSize is the auto tuned receive buffer size. + RcvBufSize int + + // RTT is the smoothed RTT as measured by observing the time between + // when a byte is first acknowledged and the receipt of data that is at + // least one window beyond the sequence number that was acknowledged. + RTT time.Duration + + // RTTVar is the "round-trip time variation" as defined in section 2 of + // RFC6298. + RTTVar time.Duration + + // RTTMeasureSeqNumber is the highest acceptable sequence number at the + // time this RTT measurement period began. + RTTMeasureSeqNumber seqnum.Value + + // RTTMeasureTime is the absolute time at which the current RTT + // measurement period began. + RTTMeasureTime time.Time `state:".(unixTime)"` + + // Disabled is true if an explicit receive buffer is set for the + // endpoint. + Disabled bool +} + +// TCPRcvBufState contains information about the state of an endpoint's receive +// socket buffer. +// +// +stateify savable +type TCPRcvBufState struct { + // RcvBufUsed is the amount of bytes actually held in the receive + // socket buffer for the endpoint. + RcvBufUsed int + + // RcvBufAutoTuneParams is used to hold state variables to compute the + // auto tuned receive buffer size. + RcvAutoParams RcvBufAutoTuneParams + + // RcvClosed if true, indicates the endpoint has been closed for + // reading. + RcvClosed bool +} + +// TCPSndBufState contains information about the state of an endpoint's send +// socket buffer. +// +// +stateify savable +type TCPSndBufState struct { + // SndBufSize is the size of the socket send buffer. 
+ SndBufSize int + + // SndBufUsed is the number of bytes held in the socket send buffer. + SndBufUsed int + + // SndClosed indicates that the endpoint has been closed for sends. + SndClosed bool + + // SndBufInQueue is the number of bytes in the send queue. + SndBufInQueue seqnum.Size + + // PacketTooBigCount is used to notify the main protocol routine how + // many times a "packet too big" control packet is received. + PacketTooBigCount int + + // SndMTU is the smallest MTU seen in the control packets received. + SndMTU int +} + +// TCPEndpointStateInner contains the members of TCPEndpointState used directly +// (that is, not within another containing struct) within the endpoint's +// internal implementation. +// +// +stateify savable +type TCPEndpointStateInner struct { + // TSOffset is a randomized offset added to the value of the TSVal + // field in the timestamp option. + TSOffset uint32 + + // SACKPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + SACKPermitted bool + + // SendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1. + SendTSOk bool + + // RecentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field + // is updated if required when a new segment is received by this + // endpoint. + RecentTS uint32 +} + +// TCPEndpointState is a copy of the internal state of a TCP endpoint. +// +// +stateify savable +type TCPEndpointState struct { + // TCPEndpointStateInner contains the members of TCPEndpointState used + // by the endpoint's internal implementation. + TCPEndpointStateInner + + // ID is a copy of the TransportEndpointID for the endpoint. + ID TCPEndpointID + + // SegTime denotes the absolute time when this segment was received. 
+ SegTime time.Time `state:".(unixTime)"` + + // RcvBufState contains information about the state of the endpoint's + // receive socket buffer. + RcvBufState TCPRcvBufState + + // SndBufState contains information about the state of the endpoint's + // send socket buffer. + SndBufState TCPSndBufState + + // SACK holds TCP SACK related information for this endpoint. + SACK TCPSACKInfo + + // Receiver holds variables related to the TCP receiver for the + // endpoint. + Receiver TCPReceiverState + + // Sender holds state related to the TCP Sender for the endpoint. + Sender TCPSenderState +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e188efccb..80ad1a9d4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { return eps } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { +// handlePacket is called by the stack when new packets arrive to this transport +// endpoint. It returns false if the packet could not be matched to any +// transport endpoint, true otherwise. +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return false } } @@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return true } // multiPortEndpoints are guaranteed to have at least one element. 
transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() - return + return true } transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. + return true } // handleError delivers an error to the transport endpoint identified by id. @@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, } return false } - ep.handlePacket(id, pkt) - return true + return ep.handlePacket(id, pkt) } // deliverRawPacket attempts to deliver the given packet and returns whether it diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 054cced0c..839178809 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -70,7 +70,7 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} - ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) return ep } @@ -106,7 +106,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions Data: buffer.View(v).ToVectorisedView(), }) _ = pkt.TransportHeader().Push(fakeTransHeaderLen) - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := f.route.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, err } @@ -233,7 +233,7 @@ func (f 
*fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress(), route: route, } - ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) f.acceptQueue = append(f.acceptQueue, ep) } diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go new file mode 100644 index 000000000..7ce43a68e --- /dev/null +++ b/pkg/tcpip/stdclock.go @@ -0,0 +1,130 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +// stdClock implements Clock with the time package. +// +// +stateify savable +type stdClock struct { + // baseTime holds the time when the clock was constructed. + // + // This value is used to calculate the monotonic time from the time package. + // As per https://golang.org/pkg/time/#hdr-Monotonic_Clocks, + // + // Operating systems provide both a “wall clock,” which is subject to + // changes for clock synchronization, and a “monotonic clock,” which is not. + // The general rule is that the wall clock is for telling time and the + // monotonic clock is for measuring time. 
Rather than split the API, in this + // package the Time returned by time.Now contains both a wall clock reading + // and a monotonic clock reading; later time-telling operations use the wall + // clock reading, but later time-measuring operations, specifically + // comparisons and subtractions, use the monotonic clock reading. + // + // ... + // + // If Times t and u both contain monotonic clock readings, the operations + // t.After(u), t.Before(u), t.Equal(u), and t.Sub(u) are carried out using + // the monotonic clock readings alone, ignoring the wall clock readings. If + // either t or u contains no monotonic clock reading, these operations fall + // back to using the wall clock readings. + // + // Given the above, we can safely conclude that time.Since(baseTime) will + // return monotonically increasing values if we use time.Now() to set baseTime + // at the time of clock construction. + // + // Note that time.Since(t) is shorthand for time.Now().Sub(t), as per + // https://golang.org/pkg/time/#Since. + baseTime time.Time `state:"nosave"` + + // monotonicOffset is the offset applied to the calculated monotonic time. + // + // monotonicOffset is assigned maxMonotonic after restore so that the + // monotonic time will continue from where it "left off" before saving as part + // of S/R. + monotonicOffset int64 `state:"nosave"` + + // monotonicMU protects maxMonotonic. + monotonicMU sync.Mutex `state:"nosave"` + maxMonotonic int64 +} + +// NewStdClock returns an instance of a clock that uses the time package. +func NewStdClock() Clock { + return &stdClock{ + baseTime: time.Now(), + } +} + +var _ Clock = (*stdClock)(nil) + +// NowNanoseconds implements Clock.NowNanoseconds. +func (*stdClock) NowNanoseconds() int64 { + return time.Now().UnixNano() +} + +// NowMonotonic implements Clock.NowMonotonic. 
+func (s *stdClock) NowMonotonic() int64 { + sinceBase := time.Since(s.baseTime) + if sinceBase < 0 { + panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime)) + } + + monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset + + s.monotonicMU.Lock() + defer s.monotonicMU.Unlock() + + // Monotonic time values must never decrease. + if monotonicValue > s.maxMonotonic { + s.maxMonotonic = monotonicValue + } + + return s.maxMonotonic +} + +// AfterFunc implements Clock.AfterFunc. +func (*stdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/stdclock_state.go index d0f58cfaf..795db9181 100644 --- a/pkg/tcpip/transport/tcp/cubic_state.go +++ b/pkg/tcpip/stdclock_state.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,18 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tcp +package tcpip -import ( - "time" -) +import "time" -// saveT is invoked by stateify. -func (c *cubicState) saveT() unixTime { - return unixTime{c.t.Unix(), c.t.UnixNano()} -} +// afterLoad is invoked by stateify. +func (s *stdClock) afterLoad() { + s.baseTime = time.Now() -// loadT is invoked by stateify. 
-func (c *cubicState) loadT(unix unixTime) { - c.t = time.Unix(unix.second, unix.nano) + s.monotonicMU.Lock() + defer s.monotonicMU.Unlock() + s.monotonicOffset = s.maxMonotonic } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 87ea09a5e..d5f941c5f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -73,7 +73,7 @@ type Clock interface { // nanoseconds since the Unix epoch. NowNanoseconds() int64 - // NowMonotonic returns a monotonic time value. + // NowMonotonic returns a monotonic time value at nanosecond resolution. NowMonotonic() int64 // AfterFunc waits for the duration to elapse and then calls f in its own @@ -691,10 +691,6 @@ const ( // number of unread bytes in the input buffer should be returned. ReceiveQueueSizeOption - // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to - // specify the receive buffer size option. - ReceiveBufferSizeOption - // SendQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the output buffer should be returned. SendQueueSizeOption @@ -786,6 +782,13 @@ func (*TCPRecovery) isGettableTransportProtocolOption() {} func (*TCPRecovery) isSettableTransportProtocolOption() {} +// TCPAlwaysUseSynCookies indicates unconditional usage of syncookies. +type TCPAlwaysUseSynCookies bool + +func (*TCPAlwaysUseSynCookies) isGettableTransportProtocolOption() {} + +func (*TCPAlwaysUseSynCookies) isSettableTransportProtocolOption() {} + const ( // TCPRACKLossDetection indicates RACK is used for loss detection and // recovery. @@ -1020,19 +1023,6 @@ func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {} func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {} -// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify -// the number of endpoints that can be in SYN-RCVD state before the stack -// switches to using SYN cookies. 
-type TCPSynRcvdCountThresholdOption uint64 - -func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {} - -func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {} - -func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {} - -func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {} - // TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide // default for number of times SYN is retransmitted before aborting a connect. type TCPSynRetriesOption uint8 @@ -1117,6 +1107,7 @@ const ( // LingerOption is used by SetSockOpt/GetSockOpt to set/get the // duration for which a socket lingers before returning from Close. // +// +marshal // +stateify savable type LingerOption struct { Enabled bool @@ -1150,6 +1141,19 @@ type SendBufferSizeOption struct { Max int } +// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max receive buffer sizes. +type ReceiveBufferSizeOption struct { + // Min is the minimum size for send buffer. + Min int + + // Default is the default size for send buffer. + Default int + + // Max is the maximum size for send buffer. + Max int +} + // GetSendBufferLimits is used to get the send buffer size limits. type GetSendBufferLimits func(StackHandler) SendBufferSizeOption @@ -1162,6 +1166,18 @@ func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption { return ss } +// GetReceiveBufferLimits is used to get the send buffer size limits. +type GetReceiveBufferLimits func(StackHandler) ReceiveBufferSizeOption + +// GetStackReceiveBufferLimits is used to get default, min and max send buffer size. +func GetStackReceiveBufferLimits(so StackHandler) ReceiveBufferSizeOption { + var ss ReceiveBufferSizeOption + if err := so.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + return ss +} + // Route is a row in the routing table. 
It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. @@ -1218,7 +1234,7 @@ func (s *StatCounter) Decrement() { } // Value returns the current value of the counter. -func (s *StatCounter) Value() uint64 { +func (s *StatCounter) Value(name ...string) uint64 { return atomic.LoadUint64(&s.count) } @@ -1512,6 +1528,30 @@ type IGMPStats struct { // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats) } +// IPForwardingStats collects stats related to IP forwarding (both v4 and v6). +type IPForwardingStats struct { + // Unrouteable is the number of IP packets received which were dropped + // because the netstack could not construct a route to their + // destination. + Unrouteable *StatCounter + + // ExhaustedTTL is the number of IP packets received which were dropped + // because their TTL was exhausted. + ExhaustedTTL *StatCounter + + // LinkLocalSource is the number of IP packets which were dropped + // because they contained a link-local source address. + LinkLocalSource *StatCounter + + // LinkLocalDestination is the number of IP packets which were dropped + // because they contained a link-local destination address. + LinkLocalDestination *StatCounter + + // Errors is the number of IP packets received which could not be + // successfully forwarded. + Errors *StatCounter +} + // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // LINT.IfChange(IPStats) @@ -1562,6 +1602,10 @@ type IPStats struct { // chain. IPTablesOutputDropped *StatCounter + // IPTablesPostroutingDropped is the number of IP packets dropped in the + // Postrouting chain. + IPTablesPostroutingDropped *StatCounter + // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out // of IPStats. // OptionTimestampReceived is the number of Timestamp options seen. 
@@ -1576,6 +1620,9 @@ type IPStats struct { // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived *StatCounter + // Forwarding collects stats related to IP forwarding. + Forwarding IPForwardingStats + // LINT.ThenChange(network/internal/ip/stats.go:MultiCounterIPStats) } @@ -1734,6 +1781,10 @@ type TCPStats struct { // ChecksumErrors is the number of segments dropped due to bad checksums. ChecksumErrors *StatCounter + + // FailedPortReservations is the number of times TCP failed to reserve + // a port. + FailedPortReservations *StatCounter } // UDPStats collects UDP-specific stats. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 3cc8c36f1..d4f7bb5ff 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -9,11 +9,14 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -78,6 +81,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", @@ -101,6 +105,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -123,6 +128,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index d10ae05c2..dbd279c94 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ 
b/pkg/tcpip/tests/integration/forward_test.go @@ -21,11 +21,14 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -312,3 +315,194 @@ func TestForwarding(t *testing.T) { }) } } + +func TestMulticastForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + ttl = 64 + ) + + var ( + ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") + ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + + ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") + ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") + ) + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo))) + } + + v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest))) + } + + tests := []struct { + name string + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, 
tcpip.Address, tcpip.Address) + expectForward bool + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 link-local multicast destination", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4LinkLocalMulticastAddr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 link-local source", + srcAddr: ipv4LinkLocalUnicastAddr, + dstAddr: utils.RemoteIPv4Addr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 link-local destination", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4LinkLocalUnicastAddr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 non-link-local unicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 non-link-local multicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4GlobalMulticastAddr, + rx: rxICMPv4EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + }, + }, + + { + name: "IPv6 link-local multicast destination", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6LinkLocalMulticastAddr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 link-local source", + srcAddr: ipv6LinkLocalUnicastAddr, + dstAddr: utils.RemoteIPv6Addr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 link-local destination", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6LinkLocalUnicastAddr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 non-link-local unicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v6Checker(t, b, utils.RemoteIPv6Addr, 
utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 non-link-local multicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6GlobalMulticastAddr, + rx: rxICMPv6EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + } + + if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID2, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID2, + }, + }) + + test.rx(e1, test.srcAddr, test.dstAddr) + + p, ok := e2.Read() + if ok != test.expectForward { + t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward) + } + + if test.expectForward { + 
test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 1cfd854a0..c61d4e788 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -347,7 +347,7 @@ type channelEndpointWithoutWritePacket struct { t *testing.T } -func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { +func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") return &tcpip.ErrNotSupported{} } @@ -627,7 +627,7 @@ func TestIPTableWritePackets(t *testing.T) { pkts := test.genPacket(r) pktsLen := pkts.Len() - if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{ + if n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ Protocol: header.UDPProtocolNumber, TTL: 64, }); err != nil { diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index d39809e1c..c657714ba 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -687,10 +687,10 @@ func TestWritePacketsLinkResolution(t *testing.T) { TOS: stack.DefaultTOS, } - if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil { - t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err) + if n, err := r.WritePackets(pkts, params); err != nil { + t.Fatalf("r.WritePackets(_, %#v): %s", params, err) } else if want := pkts.Len(); want != n { - t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want) + t.Fatalf("got r.WritePackets(_, %#v) = %d, want = %d", params, n, want) } var writer bytes.Buffer diff --git 
a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 2c538a43e..3df1bbd68 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -314,11 +315,11 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { TOS: stack.DefaultTOS, } data := buffer.View([]byte{1, 2, 3, 4}) - if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + if err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: data.ToVectorisedView(), })); err != nil { - t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err) + t.Fatalf("r.WritePacket(%#v, _): %s", params, err) } // Removing the address should make the endpoint invalid. 
@@ -326,12 +327,12 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) } { - err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: data.ToVectorisedView(), })) if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { - t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) + t.Fatalf("got r.WritePacket(%#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) } } } @@ -510,25 +511,25 @@ func TestExternalLoopbackTraffic(t *testing.T) { nicID1 = 1 nicID2 = 2 - ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") - numPackets = 1 + ttl = 64 ) + ipv4Loopback := testutil.MustParse4("127.0.0.1") loopbackSourcedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address) + utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl) } loopbackSourcedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address) + utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl) } loopbackDestinedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback) + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl) } loopbackDestinedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback) + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl) } invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index c6a9c2393..2d0a6e6a7 100644 --- 
a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -43,12 +44,15 @@ const ( // to a multicast or broadcast address uses a unicast source address for the // reply. func TestPingMulticastBroadcast(t *testing.T) { - const nicID = 1 + const ( + nicID = 1 + ttl = 64 + ) tests := []struct { name string protoNum tcpip.NetworkProtocolNumber - rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address) + rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8) srcAddr tcpip.Address dstAddr tcpip.Address expectedSrc tcpip.Address @@ -136,7 +140,7 @@ func TestPingMulticastBroadcast(t *testing.T) { }, }) - test.rxICMP(e, test.srcAddr, test.dstAddr) + test.rxICMP(e, test.srcAddr, test.dstAddr, ttl) pkt, ok := e.Read() if !ok { t.Fatal("expected ICMP response") @@ -435,10 +439,10 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { // interested endpoints. 
func TestReuseAddrAndBroadcast(t *testing.T) { const ( - nicID = 1 - localPort = 9000 - loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") + nicID = 1 + localPort = 9000 ) + loopbackBroadcast := testutil.MustParse4("127.255.255.255") tests := []struct { name string diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 78244f4eb..ac3c703d4 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -40,13 +41,13 @@ import ( // This tests that a local route is created and packets do not leave the stack. func TestLocalPing(t *testing.T) { const ( - nicID = 1 - ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + nicID = 1 // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo // request/reply packets. icmpDataOffset = 8 ) + ipv4Loopback := testutil.MustParse4("127.0.0.1") channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index d1c9f3a94..8fd9be32b 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -48,10 +48,6 @@ const ( LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") ) -const ( - ttl = 255 -) - // Common IP addresses used by tests. var ( Ipv4Addr = tcpip.AddressWithPrefix{ @@ -322,7 +318,7 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. // RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on // the provided endpoint. 
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -347,7 +343,7 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { // RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on // the provided endpoint. -func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD new file mode 100644 index 000000000..472545a5d --- /dev/null +++ b/pkg/tcpip/testutil/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + testonly = True, + srcs = ["testutil.go"], + visibility = ["//visibility:public"], + deps = ["//pkg/tcpip"], +) + +go_test( + name = "testutil_test", + srcs = ["testutil_test.go"], + library = ":testutil", + deps = ["//pkg/tcpip"], +) diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go new file mode 100644 index 000000000..1aaed590f --- /dev/null +++ b/pkg/tcpip/testutil/testutil.go @@ -0,0 +1,43 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil provides helper functions for netstack unit tests. +package testutil + +import ( + "fmt" + "net" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address. +// Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes. +func MustParse4(addr string) tcpip.Address { + ip := net.ParseIP(addr).To4() + if ip == nil { + panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr)) + } + return tcpip.Address(ip) +} + +// MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing +// an IPv4 address will yield an IPv4-mapped IPv6 address. +func MustParse6(addr string) tcpip.Address { + ip := net.ParseIP(addr).To16() + if ip == nil { + panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr)) + } + return tcpip.Address(ip) +} diff --git a/pkg/tcpip/testutil/testutil_test.go b/pkg/tcpip/testutil/testutil_test.go new file mode 100644 index 000000000..6aad9585d --- /dev/null +++ b/pkg/tcpip/testutil/testutil_test.go @@ -0,0 +1,103 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// Who tests the testutils? + +func TestMustParse4(t *testing.T) { + tcs := []struct { + str string + addr tcpip.Address + shouldPanic bool + }{ + { + str: "127.0.0.1", + addr: "\x7f\x00\x00\x01", + }, { + str: "", + shouldPanic: true, + }, { + str: "fe80::1", + shouldPanic: true, + }, { + // In an ideal world this panics too, but net.IP + // doesn't distinguish between IPv4 and IPv4-mapped + // addresses. + str: "::ffff:0.0.0.1", + addr: "\x00\x00\x00\x01", + }, + } + + for _, tc := range tcs { + t.Run(tc.str, func(t *testing.T) { + if tc.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("panic expected, but did not occur") + } + }() + } + if got := MustParse4(tc.str); got != tc.addr { + t.Errorf("got MustParse4(%s) = %s, want = %s", tc.str, got, tc.addr) + } + }) + } +} + +func TestMustParse6(t *testing.T) { + tcs := []struct { + str string + addr tcpip.Address + shouldPanic bool + }{ + { + // In an ideal world this panics too, but net.IP + // doesn't distinguish between IPv4 and IPv4-mapped + // addresses. 
+ str: "127.0.0.1", + addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x7f\x00\x00\x01", + }, { + str: "", + shouldPanic: true, + }, { + str: "fe80::1", + addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + }, { + str: "::ffff:0.0.0.1", + addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01", + }, + } + + for _, tc := range tcs { + t.Run(tc.str, func(t *testing.T) { + if tc.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("panic expected, but did not occur") + } + }() + } + if got := MustParse6(tc.str); got != tc.addr { + t.Errorf("got MustParse6(%s) = %s, want = %s", tc.str, got, tc.addr) + } + }) + } +} diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go deleted file mode 100644 index eeea97b12..000000000 --- a/pkg/tcpip/time_unsafe.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.9 -// +build !go1.18 - -// Check go:linkname function signatures when updating Go version. - -package tcpip - -import ( - "time" // Used with go:linkname. - _ "unsafe" // Required for go:linkname. -) - -// StdClock implements Clock with the time package. -// -// +stateify savable -type StdClock struct{} - -var _ Clock = (*StdClock)(nil) - -//go:linkname now time.now -func now() (sec int64, nsec int32, mono int64) - -// NowNanoseconds implements Clock.NowNanoseconds. 
-func (*StdClock) NowNanoseconds() int64 { - sec, nsec, _ := now() - return sec*1e9 + int64(nsec) -} - -// NowMonotonic implements Clock.NowMonotonic. -func (*StdClock) NowMonotonic() int64 { - _, _, mono := now() - return mono -} - -// AfterFunc implements Clock.AfterFunc. -func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { - return &stdTimer{ - t: time.AfterFunc(d, f), - } -} - -type stdTimer struct { - t *time.Timer -} - -var _ Timer = (*stdTimer)(nil) - -// Stop implements Timer.Stop. -func (st *stdTimer) Stop() bool { - return st.t.Stop() -} - -// Reset implements Timer.Reset. -func (st *stdTimer) Reset(d time.Duration) { - st.t.Reset(d) -} - -// NewStdTimer returns a Timer implemented with the time package. -func NewStdTimer(t *time.Timer) Timer { - return &stdTimer{t: t} -} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index a82384c49..1633d0aeb 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -29,7 +29,7 @@ const ( ) func TestJobReschedule(t *testing.T) { - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var wg sync.WaitGroup var lock sync.Mutex @@ -43,7 +43,7 @@ func TestJobReschedule(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). 
- job := tcpip.NewJob(&clock, &lock, func() { + job := tcpip.NewJob(clock, &lock, func() { wg.Done() }) job.Schedule(shortDuration) @@ -56,11 +56,11 @@ func TestJobReschedule(t *testing.T) { func TestJobExecution(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) - job := tcpip.NewJob(&clock, &lock, func() { + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) @@ -83,11 +83,11 @@ func TestJobExecution(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(middleDuration) lock.Lock() @@ -114,12 +114,12 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -151,13 +151,13 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -174,12 +174,12 @@ func TestJobImmediatelyCancel(t *testing.T) { func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := 
tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -206,12 +206,12 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. @@ -239,12 +239,12 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) for i := 0; i < 10; i++ { job.Cancel() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 50991c3c0..8afde7fca 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -63,12 +63,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList icmpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList icmpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. 
mu sync.RWMutex `state:"nosave"` @@ -84,6 +83,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { @@ -93,19 +96,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - state: stateInitial, - uniqueID: s.UniqueID(), + waiterQueue: waiterQueue, + state: stateInitial, + uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. 
var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } + var rs tcpip.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) + } return ep, nil } @@ -371,12 +378,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.rcvMu.Lock() v := int(e.ttl) @@ -430,7 +431,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() return err } @@ -477,7 +478,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() } @@ -746,8 +747,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) - // TODO(b/129292233): Determine if len(h) check is still needed after early - // parsing. + // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed + // after early parsing. 
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -755,8 +756,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } case header.IPv6ProtocolNumber: h := header.ICMPv6(pkt.TransportHeader().View()) - // TODO(b/129292233): Determine if len(h) check is still needed after early - // parsing. + // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed + // after early parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -774,7 +775,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -843,3 +845,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index a3c6db5a8..28a56a2d5 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -36,40 +36,21 @@ func (p *icmpPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. 
-func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) if e.state != stateBound && e.state != stateConnected { return diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 52ed9560c..496eca581 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -72,11 +72,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList packetList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. 
mu sync.RWMutex `state:"nosave"` @@ -91,6 +90,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a new packet endpoint. @@ -100,12 +103,12 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, }, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -113,9 +116,9 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - ep.rcvBufSizeMax = rs.Default + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { @@ -316,28 +319,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. 
- var rs stack.ReceiveBufferSizeOption - if err := ep.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - ep.rcvMu.Lock() - ep.rcvBufSizeMax = v - ep.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } func (ep *endpoint) LastError() tcpip.Error { @@ -374,12 +356,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - v := ep.rcvBufSizeMax - ep.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -397,7 +373,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, return } - if ep.rcvBufSize >= ep.rcvBufSizeMax { + rcvBufSize := ep.ops.GetReceiveBufferSize() + if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -513,3 +490,18 @@ func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (ep *endpoint) freeze() { + ep.mu.Lock() + ep.frozen = true + ep.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (ep *endpoint) thaw() { + ep.mu.Lock() + ep.frozen = false + ep.mu.Unlock() +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index ece662c0d..5bd860d20 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -38,33 +38,14 @@ func (p *packet) loadData(data buffer.VectorisedView) { // beforeSave is invoked by stateify. 
func (ep *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - ep.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - ep.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max + ep.freeze() } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { + ep.thaw() ep.stack = stack.StackFromEnv - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e27a249cd..bcec3d2e7 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,7 +26,6 @@ package raw import ( - "fmt" "io" "gvisor.dev/gvisor/pkg/sync" @@ -69,11 +68,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList rawPacketList - rcvBufSize int - rcvBufSizeMax int `state:".(int)"` - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList rawPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. 
mu sync.RWMutex `state:"nosave"` @@ -89,6 +87,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a raw endpoint for the given protocols. @@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - associated: associated, + waiterQueue: waiterQueue, + associated: associated, } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } // Unassociated endpoints are write-only and users call Write() with IP // headers included. Because they're write-only, We don't need to // register with the stack. 
if !associated { - e.rcvBufSizeMax = 0 + e.ops.SetReceiveBufferSize(0, false) e.waiterQueue = nil return e, nil } @@ -352,7 +354,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp Data: buffer.View(payloadBytes).ToVectorisedView(), }) pkt.Owner = owner - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := route.WritePacket(stack.NetworkHeaderParams{ Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, @@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - e.rcvMu.Lock() - e.rcvBufSizeMax = v - e.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
@@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() @@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 263ec5146..5d6f2709c 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. 
- e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // If the endpoint is connected, re-connect. if e.connected { diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index a69d6624d..48417f192 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -34,14 +34,12 @@ go_library( "connect.go", "connect_unsafe.go", "cubic.go", - "cubic_state.go", "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", "protocol.go", "rack.go", - "rack_state.go", "rcv.go", "rcv_state.go", "reno.go", @@ -107,6 +105,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/tcp/testing/context", "//pkg/test/testutil", "//pkg/waiter", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 025b134e2..d4bd4e80e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -23,7 +23,6 @@ import ( "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -51,11 +50,6 @@ const ( // timestamp and the current timestamp. If the difference is greater // than maxTSDiff, the cookie is expired. 
maxTSDiff = 2 - - // SynRcvdCountThreshold is the default global maximum number of - // connections that are allowed to be in SYN-RCVD state before TCP - // starts using SYN cookies to accept connections. - SynRcvdCountThreshold uint64 = 1000 ) var ( @@ -80,9 +74,6 @@ func encodeMSS(mss uint16) uint32 { type listenContext struct { stack *stack.Stack - // synRcvdCount is a reference to the stack level synRcvdCount. - synRcvdCount *synRcvdCounter - // rcvWnd is the receive window that is sent by this listening context // in the initial SYN-ACK. rcvWnd seqnum.Size @@ -138,14 +129,12 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } - p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol) - if !ok { - panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk)) - } - l.synRcvdCount = p.SynRcvdCounter() - rand.Read(l.nonce[0][:]) - rand.Read(l.nonce[1][:]) + for i := range l.nonce { + if _, err := io.ReadFull(stk.SecureRNG(), l.nonce[i][:]); err != nil { + panic(err) + } + } return l } @@ -163,14 +152,17 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // Feed everything to the hasher. l.hasherMu.Lock() l.hasher.Reset() + + // Per hash.Hash.Writer: + // + // It never returns an error. l.hasher.Write(payload[:]) l.hasher.Write(l.nonce[nonceIndex][:]) - io.WriteString(l.hasher, string(id.LocalAddress)) - io.WriteString(l.hasher, string(id.RemoteAddress)) + l.hasher.Write([]byte(id.LocalAddress)) + l.hasher.Write([]byte(id.RemoteAddress)) // Finalize the calculation of the hash and return the first 4 bytes. 
- h := make([]byte, 0, sha1.Size) - h = l.hasher.Sum(h) + h := l.hasher.Sum(nil) l.hasherMu.Unlock() return binary.BigEndian.Uint32(h[:]) @@ -199,9 +191,17 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } +func (l *listenContext) useSynCookies() bool { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) +} + // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -215,11 +215,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n := newEndpoint(l.stack, netProto, queue) n.ops.SetV6Only(l.v6Only) - n.ID = s.id + n.TransportEndpointInfo.ID = s.id n.boundNICID = s.nicID n.route = route n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} - n.rcvBufSize = int(l.rcvWnd) + n.ops.SetReceiveBufferSize(int64(l.rcvWnd), false /* notify */) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -231,7 +231,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // Bootstrap the auto tuning algorithm. 
Starting at zero will result in // a large step function on the first window adjustment causing the // window to grow to a really large value. - n.rcvAutoParams.prevCopied = n.initialReceiveWindow() + n.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = n.initialReceiveWindow() return n, nil } @@ -248,7 +248,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err } @@ -290,7 +290,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q } // Register new endpoint so that packets are routed to it. - if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + if err := ep.stack.RegisterTransportEndpoint( + ep.effectiveNetProtos, + ProtocolNumber, + ep.TransportEndpointInfo.ID, + ep, + ep.boundPortFlags, + ep.boundBindToDevice, + ); err != nil { ep.mu.Unlock() ep.Close() @@ -307,6 +314,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Initialize and start the handshake. 
h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) + h.listenEP = l.listenEP h.start() return h, nil } @@ -334,14 +342,14 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.ID] = n + l.pendingEndpoints[n.TransportEndpointInfo.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.ID) + delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -382,39 +390,46 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window // scaling. - e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() + e.rcv.RcvWndScale = e.h.effectiveRcvWndScale() // Clean up handshake state stored in the endpoint so that it can be GCed. e.h = nil } // deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// endpoint has transitioned out of the listen state (acceptedChan is nil), -// the new endpoint is closed instead. +// listener has transitioned out of the listen state (accepted is the zero +// value), the new endpoint is reset instead. func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Lock() e.pendingAccepted.Add(1) e.mu.Unlock() defer e.pendingAccepted.Done() - e.acceptMu.Lock() - for { - if e.acceptedChan == nil { - e.acceptMu.Unlock() - n.notifyProtocolGoroutine(notifyReset) - return - } - select { - case e.acceptedChan <- n: + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. 
+ delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + if e.accepted == (accepted{}) { + return false + } + if e.accepted.endpoints.Len() == e.accepted.cap { + e.acceptCond.Wait() + continue + } + + e.accepted.endpoints.PushBack(n) if !withSynCookie { atomic.AddInt32(&e.synRcvdCount, -1) } - e.acceptMu.Unlock() - e.waiterQueue.Notify(waiter.ReadableEvents) - return - default: - e.acceptCond.Wait() + return true } + }() + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + n.notifyProtocolGoroutine(notifyReset) } } @@ -436,17 +451,21 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // * propagateInheritableOptionsLocked has been called. // * e.mu is held. func (e *endpoint) reserveTupleLocked() bool { - dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + dest := tcpip.FullAddress{ + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, + } portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: dest, } if !e.stack.ReserveTuple(portRes) { + e.stack.Stats().TCP.FailedPortReservations.Increment() return false } @@ -485,7 +504,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } go func() { - defer ctx.synRcvdCount.dec() if err := h.complete(); err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() @@ -497,24 +515,29 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.deliverAccepted(h.ep, false /*withSynCookie*/) - }() // S/R-SAFE: synRcvdCount is the barrier. 
+ }() return nil } -func (e *endpoint) incSynRcvdCount() bool { +func (e *endpoint) synRcvdBacklogFull() bool { e.acceptMu.Lock() - canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan) + acceptedCap := e.accepted.cap e.acceptMu.Unlock() - if canInc { - atomic.AddInt32(&e.synRcvdCount, 1) - } - return canInc + // The capacity of the accepted queue would always be one greater than the + // listen backlog. But, the SYNRCVD connections count is always checked + // against the listen backlog value for Linux parity reason. + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 + // + // We maintain an equality check here as the synRcvdCount is incremented + // and compared only from a single listener context and the capacity of + // the accepted queue can only increase by a new listen call. + return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan) + full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap e.acceptMu.Unlock() return full } @@ -524,9 +547,9 @@ func (e *endpoint) acceptQueueIsFull() bool { // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { - e.rcvListMu.Lock() - rcvClosed := e.rcvClosed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := e.rcvQueueInfo.RcvClosed + e.rcvQueueInfo.rcvQueueMu.Unlock() if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { // If the endpoint is shutdown, reply with reset. 
// @@ -538,69 +561,55 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err switch { case s.flags == header.TCPFlagSyn: - opts := parseSynSegmentOptions(s) - if ctx.synRcvdCount.inc() { - // Only handle the syn if the following conditions hold - // - accept queue is not full. - // - number of connections in synRcvd state is less than the - // backlog. - if !e.acceptQueueIsFull() && e.incSynRcvdCount() { - s.incRef() - _ = e.handleSynSegment(ctx, s, &opts) - return nil - } - ctx.synRcvdCount.dec() + if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return nil - } else { - // If cookies are in use but the endpoint accept queue - // is full then drop the syn. - if e.acceptQueueIsFull() { - e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() - e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() - e.stack.Stats().DroppedPackets.Increment() - return nil - } - cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + } - route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) - if err != nil { - return err - } - defer route.Release() + opts := parseSynSegmentOptions(s) + if !ctx.useSynCookies() { + s.incRef() + atomic.AddInt32(&e.synRcvdCount, 1) + return e.handleSynSegment(ctx, s, &opts) + } + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() - // Send SYN without window scaling because we currently - // don't encode this information in the cookie. - // - // Enable Timestamp option if the original syn did have - // the timestamp option specified. - // - // Use the user supplied MSS on the listening socket for - // new connections, if available. 
- synOpts := header.TCPSynOptions{ - WS: -1, - TS: opts.TS, - TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), - TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, route), - } - fields := tcpFields{ - id: s.id, - ttl: e.ttl, - tos: e.sendTOS, - flags: header.TCPFlagSyn | header.TCPFlagAck, - seq: cookie, - ack: s.sequenceNumber + 1, - rcvWnd: ctx.rcvWnd, - } - if err := e.sendSynTCP(route, fields, synOpts); err != nil { - return err - } - e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() - return nil + // Send SYN without window scaling because we currently + // don't encode this information in the cookie. + // + // Enable Timestamp option if the original syn did have + // the timestamp option specified. + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. + synOpts := header.TCPSynOptions{ + WS: -1, + TS: opts.TS, + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), + TSEcr: opts.TSVal, + MSS: calculateAdvertisedMSS(e.userMSS, route), + } + cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + fields := tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagSyn | header.TCPFlagAck, + seq: cookie, + ack: s.sequenceNumber + 1, + rcvWnd: ctx.rcvWnd, + } + if err := e.sendSynTCP(route, fields, synOpts); err != nil { + return err } + e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { @@ -615,25 +624,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } - if !ctx.synRcvdCount.synCookiesInUse() { - // When not using SYN cookies, as per RFC 793, section 3.9, page 64: - // Any acknowledgment is bad if it arrives on a connection still in - // the LISTEN state. An acceptable reset segment should be formed - // for any arriving ACK-bearing segment. 
The RST should be - // formatted as follows: - // - // <SEQ=SEG.ACK><CTL=RST> - // - // Send a reset as this is an ACK for which there is no - // half open connections and we are not using cookies - // yet. - // - // The only time we should reach here when a connection - // was opened and closed really quickly and a delayed - // ACK was received from the sender. - return replyWithReset(e.stack, s, e.sendTOS, e.ttl) - } - iss := s.ackNumber - 1 irs := s.sequenceNumber - 1 @@ -651,7 +641,23 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() - return nil + + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // + // Send a reset as this is an ACK for which there is no + // half open connections and we are not using cookies + // yet. + // + // The only time we should reach here when a connection + // was opened and closed really quickly and a delayed + // ACK was received from the sender. + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. 
@@ -672,7 +678,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { return err } @@ -693,7 +699,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint( + n.effectiveNetProtos, + ProtocolNumber, + n.TransportEndpointInfo.ID, + n, + n.boundPortFlags, + n.boundBindToDevice, + ); err != nil { n.mu.Unlock() n.Close() @@ -708,7 +721,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was // sent above. - n.tsOffset = 0 + n.TSOffset = 0 // Switch state to connected. n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a9e978cf6..5e03e7715 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -65,11 +65,12 @@ const ( // NOTE: handshake.ep.mu is held during handshake processing. It is released if // we are going to block and reacquired when we start processing an event. type handshake struct { - ep *endpoint - state handshakeState - active bool - flags header.TCPFlags - ackNum seqnum.Value + ep *endpoint + listenEP *endpoint + state handshakeState + active bool + flags header.TCPFlags + ackNum seqnum.Value // iss is the initial send sequence number, as defined in RFC 793. 
iss seqnum.Value @@ -155,7 +156,7 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed()) } // generateSecureISN generates a secure Initial Sequence number based on the @@ -301,7 +302,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { ttl = h.ep.route.DefaultTTL() } h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -357,14 +358,14 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, - TS: h.ep.sendTSOk, + TS: h.ep.SendTSOk, TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), - SACKPermitted: h.ep.sackPermitted, + SACKPermitted: h.ep.SACKPermitted, MSS: h.ep.amss, } h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -389,13 +390,22 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. - if h.ep.sendTSOk && !s.parsedOptions.TS { + if h.ep.SendTSOk && !s.parsedOptions.TS { h.ep.stack.Stats().DroppedPackets.Increment() return nil } + // Drop the ACK if the accept queue is full. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_ipv4.c#L1523 + // We could abort the connection as well with a tunable as in + // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_minisocks.c#L788 + if listenEP := h.listenEP; listenEP != nil && listenEP.acceptQueueIsFull() { + listenEP.stack.Stats().DroppedPackets.Increment() + return nil + } + // Update timestamp if required. See RFC7323, section-4.3. 
- if h.ep.sendTSOk && s.parsedOptions.TS { + if h.ep.SendTSOk && s.parsedOptions.TS { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted @@ -485,8 +495,8 @@ func (h *handshake) start() { // start() is also called in a listen context so we want to make sure we only // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { - synOpts.TS = h.ep.sendTSOk - synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) + synOpts.TS = h.ep.SendTSOk + synOpts.SACKPermitted = h.ep.SACKPermitted && bool(sackEnabled) if h.sndWndScale < 0 { // Disable window scaling if the peer did not send us // the window scaling option. @@ -496,7 +506,7 @@ func (h *handshake) start() { h.sendSYNOpts = synOpts h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -544,7 +554,7 @@ func (h *handshake) complete() tcpip.Error { // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -576,8 +586,14 @@ func (h *handshake) complete() tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + // Check for any ICMP errors notified to us. if n&notifyError != 0 { - return h.ep.lastErrorLocked() + if err := h.ep.lastErrorLocked(); err != nil { + return err + } + // Flag the handshake failure as aborted if the lastError is + // cleared because of a socket layer call. + return &tcpip.ErrConnectionAborted{} } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -711,14 +727,14 @@ type tcpFields struct { func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error { tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send.
- if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { + if err := e.sendTCP(r, tf, buffer.VectorisedView{}, stack.GSO{}); err != nil { e.stats.SendErrors.SynSendToNetworkFailed.Increment() } putOptions(tf.opts) return nil } -func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO) tcpip.Error { tf.txHash = e.txHash if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() @@ -728,7 +744,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV return nil } -func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { +func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso stack.GSO) { optLen := len(tf.opts) tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) pkt.TransportProtocolNumber = header.TCPProtocolNumber @@ -745,7 +761,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. - if gso != nil && gso.NeedsCsum { + if gso.Type != stack.GSONone && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. 
We // calculate a checksum of the pseudo-header and save it in the // TCP header, then the kernel calculate a checksum of the @@ -757,7 +773,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO, owner tcpip.PacketOwner) tcpip.Error { // We need to shallow clone the VectorisedView here as ReadToView will // split the VectorisedView and Trim underlying views as it splits. Not // doing the clone here will cause the underlying views of data itself @@ -789,13 +805,14 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Data().ReadFromVV(&data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkt.GSOOptions = gso pkts.PushBack(pkt) } if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } - sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) + sent, err := r.WritePackets(pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) if err != nil { r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) } @@ -805,13 +822,13 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. 
-func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO, owner tcpip.PacketOwner) tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > math.MaxUint16 { tf.rcvWnd = math.MaxUint16 } - if r.Loop()&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + if r.Loop()&stack.PacketLoop == 0 && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { return sendTCPBatch(r, tf, data, gso, owner) } @@ -819,6 +836,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, Data: data, }) + pkt.GSOOptions = gso pkt.Hash = tf.txHash pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) @@ -826,7 +844,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } @@ -845,7 +863,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // N.B. the ordering here matches the ordering used by Linux internally // and described in the raw makeOptions function. We don't include // unnecessary cases here (post connection.) - if e.sendTSOk { + if e.SendTSOk { // Embed the timestamp if timestamp has been enabled. 
// // We only use the lower 32 bits of the unix time in @@ -862,7 +880,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { offset += header.EncodeNOP(options[offset:]) offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } - if e.sackPermitted && len(sackBlocks) > 0 { + if e.SACKPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) @@ -884,7 +902,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se } options := e.makeOptions(sackBlocks) err := e.sendTCP(e.route, tcpFields{ - id: e.ID, + id: e.TransportEndpointInfo.ID, ttl: e.ttl, tos: e.sendTOS, flags: flags, @@ -898,9 +916,9 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se } func (e *endpoint) handleWrite() { - e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Lock() next := e.drainSendQueueLocked() - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() e.sendData(next) } @@ -909,10 +927,10 @@ func (e *endpoint) handleWrite() { // // Precondition: e.sndBufMu must be locked. func (e *endpoint) drainSendQueueLocked() *segment { - first := e.sndQueue.Front() + first := e.sndQueueInfo.sndQueue.Front() if first != nil { - e.snd.writeList.PushBackList(&e.sndQueue) - e.sndBufInQueue = 0 + e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue) + e.sndQueueInfo.SndBufInQueue = 0 } return first } @@ -936,7 +954,7 @@ func (e *endpoint) handleClose() { e.handleWrite() // Mark send side as closed. - e.snd.closed = true + e.snd.Closed = true } // resetConnectionLocked puts the endpoint in an error state with the given @@ -958,12 +976,12 @@ func (e *endpoint) resetConnectionLocked(err tcpip.Error) { // // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more // information. 
- sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + sndWndEnd := e.snd.SndUna.Add(e.snd.SndWnd) resetSeqNum := sndWndEnd - if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { - resetSeqNum = e.snd.sndNxt + if !sndWndEnd.LessThan(e.snd.SndNxt) || e.snd.SndNxt.Size(sndWndEnd) < (1<<e.snd.SndWndScale) { + resetSeqNum = e.snd.SndNxt } - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.RcvNxt, 0) } } @@ -989,13 +1007,13 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. - e.rcvAutoParams.prevCopied = int(h.rcvWnd) - e.rcvListMu.Unlock() + e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd) + e.rcvQueueInfo.rcvQueueMu.Unlock() e.setEndpointState(StateEstablished) } @@ -1026,10 +1044,15 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.TransportEndpointInfo.ID, s.nicID) if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. 
- ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) + ep = e.stack.FindTransportEndpoint( + header.IPv4ProtocolNumber, + e.TransProto, + e.TransportEndpointInfo.ID, + s.nicID, + ) } if ep == nil { replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) @@ -1108,7 +1131,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { } // handleSegments processes all inbound segments. -func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { +// +// Precondition: e.mu must be held. +func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { if e.EndpointState().closed() { @@ -1120,7 +1145,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { break } - cont, err := e.handleSegment(s) + cont, err := e.handleSegmentLocked(s) s.decRef() if err != nil { return err @@ -1138,7 +1163,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { } // Send an ACK for all processed packets if needed. - if e.rcv.rcvNxt != e.snd.maxSentAck { + if e.rcv.RcvNxt != e.snd.MaxSentAck { e.snd.sendAck() } @@ -1147,18 +1172,21 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { return nil } -func (e *endpoint) probeSegment() { - if e.probe != nil { - e.probe(e.completeState()) +// Precondition: e.mu must be held. +func (e *endpoint) probeSegmentLocked() { + if fn := e.probe; fn != nil { + fn(e.completeStateLocked()) } } // handleSegment handles a given segment and notifies the worker goroutine if // if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { +// +// Precondition: e.mu must be held. +func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) { // Invoke the tcp probe if installed. The tcp probe function will update // the TCPEndpointState after the segment is processed. 
- defer e.probeSegment() + defer e.probeSegmentLocked() if s.flagIsSet(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { @@ -1191,7 +1219,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. - s.window <<= e.snd.sndWndScale + s.window <<= e.snd.SndWndScale // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment @@ -1255,7 +1283,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error { // seg.seq = snd.nxt-1. e.keepalive.unacked++ e.keepalive.Unlock() - e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1) + e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.SndNxt-1) e.resetKeepaliveTimer(false) return nil } @@ -1269,7 +1297,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } // Start the keepalive timer IFF it's enabled and there is no pending // data to send. - if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.SndUna != e.snd.SndNxt { e.keepalive.timer.disable() e.keepalive.Unlock() return @@ -1340,8 +1368,24 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Reaching this point means that we successfully completed the 3-way - // handshake with our peer. - // + // handshake with our peer. The current endpoint state could be any state + // post ESTABLISHED, including CLOSED or ERROR if the endpoint processes a + // RST from the peer via the dispatcher fast path, before the loop is + // started. + if s := e.EndpointState(); !s.connected() { + switch s { + case StateClose, StateError: + // If the endpoint is in CLOSED/ERROR state, sender state has to be + // initialized if the endpoint was previously established. 
+ if e.snd != nil { + break + } + fallthrough + default: + panic("endpoint was not established, current state " + s.String()) + } + } + // Completing the 3-way handshake is an indication that the route is valid // and the remote is reachable as the only way we can complete a handshake // is if our SYN reached the remote and their ACK reached us. @@ -1362,14 +1406,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ f func() tcpip.Error }{ { - w: &e.sndWaker, + w: &e.sndQueueInfo.sndWaker, f: func() tcpip.Error { e.handleWrite() return nil }, }, { - w: &e.sndCloseWaker, + w: &e.sndQueueInfo.sndCloseWaker, f: func() tcpip.Error { e.handleClose() return nil @@ -1403,7 +1447,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ { w: &e.newSegmentWaker, f: func() tcpip.Error { - return e.handleSegments(false /* fastPath */) + return e.handleSegmentsLocked(false /* fastPath */) }, }, { @@ -1419,11 +1463,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n&notifyMTUChanged != 0 { - e.sndBufMu.Lock() - count := e.packetTooBigCount - e.packetTooBigCount = 0 - mtu := e.sndMTU - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Lock() + count := e.sndQueueInfo.PacketTooBigCount + e.sndQueueInfo.PacketTooBigCount = 0 + mtu := e.sndQueueInfo.SndMTU + e.sndQueueInfo.sndQueueMu.Unlock() e.snd.updateMaxPayloadSize(mtu, count) } @@ -1453,7 +1497,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if n&notifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(false /* fastPath */); err != nil { + if err := e.handleSegmentsLocked(false /* fastPath */); err != nil { return err } } @@ -1504,11 +1548,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.newSegmentWaker.Assert() } - e.rcvListMu.Lock() - if !e.rcvList.Empty() { 
e.waiterQueue.Notify(waiter.ReadableEvents) } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 1975f1a44..962f1d687 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -17,6 +17,8 @@ package tcp import ( "math" "time" + + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // cubicState stores the variables related to TCP CUBIC congestion @@ -25,47 +27,12 @@ import ( // See: https://tools.ietf.org/html/rfc8312. // +stateify savable type cubicState struct { - // wLastMax is the previous wMax value. - wLastMax float64 - - // wMax is the value of the congestion window at the - // time of last congestion event. - wMax float64 - - // t denotes the time when the current congestion avoidance - // was entered. - t time.Time `state:".(unixTime)"` + stack.TCPCubicState // numCongestionEvents tracks the number of congestion events since last // RTO. numCongestionEvents int - // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as - // per RFC. - c float64 - - // k is the time period that the above function takes to increase the - // current window size to W_max if there are no further congestion - // events and is calculated using the following equation: - // - // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2) - k float64 - - // beta is the CUBIC multiplication decrease factor. that is, when a - // congestion event is detected, CUBIC reduces its cwnd to - // W_cubic(0)=W_max*beta_cubic. - beta float64 - - // wC is window computed by CUBIC at time t. It's calculated using the - // formula: - // - // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) - wC float64 - - // wEst is the window computed by CUBIC at time t+RTT i.e - // W_cubic(t+RTT). - wEst float64 - s *sender } @@ -73,10 +40,12 @@ type cubicState struct { // beta and c set and t set to current time. 
func newCubicCC(s *sender) *cubicState { return &cubicState{ - t: time.Now(), - beta: 0.7, - c: 0.4, - s: s, + TCPCubicState: stack.TCPCubicState{ + T: time.Now(), + Beta: 0.7, + C: 0.4, + }, + s: s, } } @@ -90,10 +59,10 @@ func (c *cubicState) enterCongestionAvoidance() { // See: https://tools.ietf.org/html/rfc8312#section-4.7 & // https://tools.ietf.org/html/rfc8312#section-4.8 if c.numCongestionEvents == 0 { - c.k = 0 - c.t = time.Now() - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.K = 0 + c.T = time.Now() + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) } } @@ -104,16 +73,16 @@ func (c *cubicState) enterCongestionAvoidance() { func (c *cubicState) updateSlowStart(packetsAcked int) int { // Don't let the congestion window cross into the congestion // avoidance range. - newcwnd := c.s.sndCwnd + packetsAcked + newcwnd := c.s.SndCwnd + packetsAcked enterCA := false - if newcwnd >= c.s.sndSsthresh { - newcwnd = c.s.sndSsthresh - c.s.sndCAAckCount = 0 + if newcwnd >= c.s.Ssthresh { + newcwnd = c.s.Ssthresh + c.s.SndCAAckCount = 0 enterCA = true } - packetsAcked -= newcwnd - c.s.sndCwnd - c.s.sndCwnd = newcwnd + packetsAcked -= newcwnd - c.s.SndCwnd + c.s.SndCwnd = newcwnd if enterCA { c.enterCongestionAvoidance() } @@ -124,49 +93,49 @@ func (c *cubicState) updateSlowStart(packetsAcked int) int { // ACK received. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) Update(packetsAcked int) { - if c.s.sndCwnd < c.s.sndSsthresh { + if c.s.SndCwnd < c.s.Ssthresh { packetsAcked = c.updateSlowStart(packetsAcked) if packetsAcked == 0 { return } } else { c.s.rtt.Lock() - srtt := c.s.rtt.srtt + srtt := c.s.rtt.TCPRTTState.SRTT c.s.rtt.Unlock() - c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt) + c.s.SndCwnd = c.getCwnd(packetsAcked, c.s.SndCwnd, srtt) } } // cubicCwnd computes the CUBIC congestion window after t seconds from last // congestion event. 
func (c *cubicState) cubicCwnd(t float64) float64 { - return c.c*math.Pow(t, 3.0) + c.wMax + return c.C*math.Pow(t, 3.0) + c.WMax } // getCwnd returns the current congestion window as computed by CUBIC. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { - elapsed := time.Since(c.t).Seconds() + elapsed := time.Since(c.T).Seconds() // Compute the window as per Cubic after 'elapsed' time // since last congestion event. - c.wC = c.cubicCwnd(elapsed - c.k) + c.WC = c.cubicCwnd(elapsed - c.K) // Compute the TCP friendly estimate of the congestion window. - c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds()) + c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds()) // Make sure in the TCP friendly region CUBIC performs at least // as well as Reno. - if c.wC < c.wEst && float64(sndCwnd) < c.wEst { + if c.WC < c.WEst && float64(sndCwnd) < c.WEst { // TCP Friendly region of cubic. - return int(c.wEst) + return int(c.WEst) } // In Concave/Convex region of CUBIC, calculate what CUBIC window // will be after 1 RTT and use that to grow congestion window // for every ack. - tEst := (time.Since(c.t) + srtt).Seconds() - wtRtt := c.cubicCwnd(tEst - c.k) + tEst := (time.Since(c.T) + srtt).Seconds() + wtRtt := c.cubicCwnd(tEst - c.K) // As per 4.3 for each received ACK cwnd must be incremented // by (w_cubic(t+RTT) - cwnd/cwnd. 
cwnd := float64(sndCwnd) @@ -182,9 +151,9 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ - c.t = time.Now() - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.T = time.Now() + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) c.fastConvergence() c.reduceSlowStartThreshold() @@ -193,10 +162,10 @@ func (c *cubicState) HandleLossDetected() { // HandleRTOExpired implements congestionContrl.HandleRTOExpired. func (c *cubicState) HandleRTOExpired() { // See: https://tools.ietf.org/html/rfc8312#section-4.6 - c.t = time.Now() + c.T = time.Now() c.numCongestionEvents = 0 - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) c.fastConvergence() @@ -206,29 +175,29 @@ func (c *cubicState) HandleRTOExpired() { // Reduce the congestion window to 1, i.e., enter slow-start. Per // RFC 5681, page 7, we must use 1 regardless of the value of the // initial congestion window. - c.s.sndCwnd = 1 + c.s.SndCwnd = 1 } // fastConvergence implements the logic for Fast Convergence algorithm as // described in https://tools.ietf.org/html/rfc8312#section-4.6. func (c *cubicState) fastConvergence() { - if c.wMax < c.wLastMax { - c.wLastMax = c.wMax - c.wMax = c.wMax * (1.0 + c.beta) / 2.0 + if c.WMax < c.WLastMax { + c.WLastMax = c.WMax + c.WMax = c.WMax * (1.0 + c.Beta) / 2.0 } else { - c.wLastMax = c.wMax + c.WLastMax = c.WMax } // Recompute k as wMax may have changed. - c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c) + c.K = math.Cbrt(c.WMax * (1 - c.Beta) / c.C) } // PostRecovery implemements congestionControl.PostRecovery. func (c *cubicState) PostRecovery() { - c.t = time.Now() + c.T = time.Now() } // reduceSlowStartThreshold returns new SsThresh as described in // https://tools.ietf.org/html/rfc8312#section-4.7. 
func (c *cubicState) reduceSlowStartThreshold() { - c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0)) + c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0)) } diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 21162f01a..512053a04 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -116,7 +116,7 @@ func (p *processor) start(wg *sync.WaitGroup) { if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { // If the endpoint is in a connected state then we do direct delivery // to ensure low latency and avoid scheduler interactions. - switch err := ep.handleSegments(true /* fastPath */); { + switch err := ep.handleSegmentsLocked(true /* fastPath */); { case err != nil: // Send any active resets if required. ep.resetConnectionLocked(err) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f6a16f96e..f148d505d 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -37,8 +38,8 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { // Start connection attempt, it must fail. err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) 
mismatch (-want +got):\n%s", d) } } @@ -49,8 +50,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -156,8 +157,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -391,7 +392,7 @@ func testV4Accept(t *testing.T, c *context.Context) { defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -525,7 +526,7 @@ func TestV6AcceptOnV6(t *testing.T) { defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress _, _, err := c.EP.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -565,17 +566,15 @@ func TestV4AcceptOnV4(t *testing.T) { } func testV4ListenClose(t *testing.T, c *context.Context) { - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. 
- var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } - const n = uint16(32) + const n = 32 // Start listening. - if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil { + if err := c.EP.Listen(n); err != nil { t.Fatalf("Listen failed: %v", err) } @@ -591,9 +590,9 @@ func testV4ListenClose(t *testing.T, c *context.Context) { }) } - // Each of these ACK's will cause a syn-cookie based connection to be + // Each of these ACKs will cause a syn-cookie based connection to be // accepted and delivered to the listening endpoint. - for i := uint16(0); i < n; i++ { + for i := 0; i < n; i++ { b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) iss := seqnum.Value(tcp.SequenceNumber()) @@ -613,7 +612,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.ReadableEvents) defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c5daba232..90edcfba6 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "container/list" "encoding/binary" "fmt" "io" @@ -190,42 +191,6 @@ type SACKInfo struct { NumBlocks int } -// rcvBufAutoTuneParams are used to hold state variables to compute -// the auto tuned recv buffer size. -// -// +stateify savable -type rcvBufAutoTuneParams struct { - // measureTime is the time at which the current measurement - // was started. 
- measureTime time.Time `state:".(unixTime)"` - - // copied is the number of bytes copied out of the receive - // buffers since this measure began. - copied int - - // prevCopied is the number of bytes copied out of the receive - // buffers in the previous RTT period. - prevCopied int - - // rtt is the non-smoothed minimum RTT as measured by observing the time - // between when a byte is first acknowledged and the receipt of data - // that is at least one window beyond the sequence number that was - // acknowledged. - rtt time.Duration - - // rttMeasureSeqNumber is the highest acceptable sequence number at the - // time this RTT measurement period began. - rttMeasureSeqNumber seqnum.Value - - // rttMeasureTime is the absolute time at which the current rtt - // measurement period began. - rttMeasureTime time.Time `state:".(unixTime)"` - - // disabled is true if an explicit receive buffer is set for the - // endpoint. - disabled bool -} - // ReceiveErrors collect segment receive errors within transport layer. type ReceiveErrors struct { tcpip.ReceiveErrors @@ -246,7 +211,7 @@ type ReceiveErrors struct { ListenOverflowAckDrop tcpip.StatCounter // ZeroRcvWindowState is the number of times we advertised - // a zero receive window when rcvList is full. + // a zero receive window when rcvQueue is full. ZeroRcvWindowState tcpip.StatCounter // WantZeroWindow is the number of times we wanted to advertise a @@ -309,18 +274,45 @@ type Stats struct { // marker interface. func (*Stats) IsEndpointStats() {} -// EndpointInfo holds useful information about a transport endpoint which -// can be queried by monitoring tools. This exists to allow tcp-only state to -// be exposed. +// sndQueueInfo implements a send queue. // // +stateify savable -type EndpointInfo struct { - stack.TransportEndpointInfo +type sndQueueInfo struct { + sndQueueMu sync.Mutex `state:"nosave"` + stack.TCPSndBufState + + // sndQueue holds segments that are ready to be sent. 
+ sndQueue segmentList `state:"wait"` + + // sndWaker is used to signal the protocol goroutine when segments are + // added to the `sndQueue`. + sndWaker sleep.Waker `state:"manual"` + + // sndCloseWaker is used to notify the protocol goroutine when the send + // side is closed. + sndCloseWaker sleep.Waker `state:"manual"` } -// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo -// marker interface. -func (*EndpointInfo) IsEndpointInfo() {} +// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata. +// +// +stateify savable +type rcvQueueInfo struct { + rcvQueueMu sync.Mutex `state:"nosave"` + stack.TCPRcvBufState + + // rcvQueue is the queue for ready-for-delivery segments. This struct's + // mutex must be held in order append segments to list. + rcvQueue segmentList `state:"wait"` +} + +// +stateify savable +type accepted struct { + // NB: this could be an endpointList, but ilist only permits endpoints to + // belong to one list at a time, and endpoints are already stored in the + // dispatcher's list. + endpoints list.List `state:".([]*endpoint)"` + cap int +} // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to @@ -337,9 +329,9 @@ func (*EndpointInfo) IsEndpointInfo() {} // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects acceptedChan. -// e.rcvListMu -> Protects the rcvList and associated fields. -// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.acceptMu -> protects accepted. +// e.rcvQueueMu -> Protects e.rcvQueue and associated fields. +// e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. // // LOCKING/UNLOCKING of the endpoint. 
The locking of an endpoint is different @@ -362,7 +354,8 @@ func (*EndpointInfo) IsEndpointInfo() {} // // +stateify savable type endpoint struct { - EndpointInfo + stack.TCPEndpointStateInner + stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the @@ -395,38 +388,23 @@ type endpoint struct { // rcvReadMu synchronizes calls to Read. // - // mu and rcvListMu are temporarily released during data copying. rcvReadMu + // mu and rcvQueueMu are temporarily released during data copying. rcvReadMu // must be held during each read to ensure atomicity, so that multiple reads // do not interleave. // // rcvReadMu should be held before holding mu. rcvReadMu sync.Mutex `state:"nosave"` - // rcvListMu synchronizes access to rcvList. - // - // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - - // rcvList is the queue for ready-for-delivery segments. - // - // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data - // and removing segments from list. A range of segment can be determined, then - // temporarily release mu and rcvListMu while processing the segment range. - // This allows new segments to be appended to the list while processing. - // - // rcvListMu must be held to append segments to list. - rcvList segmentList `state:"wait"` - rcvClosed bool - // rcvBufSize is the total size of the receive buffer. - rcvBufSize int - // rcvBufUsed is the actual number of payload bytes held in the receive buffer - // not counting any overheads of the segments itself. NOTE: This will always - // be strictly <= rcvMemUsed below. - rcvBufUsed int - rcvAutoParams rcvBufAutoTuneParams + // rcvQueueInfo holds the implementation of the endpoint's receive buffer. + // The data within rcvQueueInfo should only be accessed while rcvReadMu, mu, + // and rcvQueueMu are held, in that stated order. 
While processing the segment + // range, you can determine a range and then temporarily release mu and + // rcvQueueMu, which allows new segments to be appended to the queue while + // processing. + rcvQueueInfo rcvQueueInfo // rcvMemUsed tracks the total amount of memory in use by received segments - // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // held in rcvQueue, pendingRcvdSegments and the segment queue. This is used to // compute the window and the actual available buffer space. This is distinct // from rcvBufUsed above which is the actual number of payload bytes held in // the buffer not including any segment overheads. @@ -488,33 +466,16 @@ type endpoint struct { // also true, and they're both protected by the mutex. workerCleanup bool - // sendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1 - sendTSOk bool - - // recentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - recentTS uint32 - - // recentTSTime is the unix time when we updated recentTS last. + // recentTSTime is the unix time when we last updated + // TCPEndpointStateInner.RecentTS. recentTSTime time.Time `state:".(unixTime)"` - // tsOffset is a randomized offset added to the value of the - // TSVal field in the timestamp option. - tsOffset uint32 - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags // tcpRecovery is the loss deteoction algorithm used by TCP. tcpRecovery tcpip.TCPRecovery - // sackPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - sackPermitted bool - // sack holds TCP SACK related information for this endpoint. 
sack SACKInfo @@ -550,32 +511,13 @@ type endpoint struct { // this value. windowClamp uint32 - // The following fields are used to manage the send buffer. When - // segments are ready to be sent, they are added to sndQueue and the - // protocol goroutine is signaled via sndWaker. - // - // When the send side is closed, the protocol goroutine is notified via - // sndCloseWaker, and sndClosed is set to true. - sndBufMu sync.Mutex `state:"nosave"` - sndBufUsed int - sndClosed bool - sndBufInQueue seqnum.Size - sndQueue segmentList `state:"wait"` - sndWaker sleep.Waker `state:"manual"` - sndCloseWaker sleep.Waker `state:"manual"` + // sndQueueInfo contains the implementation of the endpoint's send queue. + sndQueueInfo sndQueueInfo // cc stores the name of the Congestion Control algorithm to use for // this endpoint. cc tcpip.CongestionControlOption - // The following are used when a "packet too big" control packet is - // received. They are protected by sndBufMu. They are used to - // communicate to the main protocol goroutine how many such control - // messages have been received since the last notification was processed - // and what was the smallest MTU seen. - packetTooBigCount int - sndMTU int - // newSegmentWaker is used to indicate to the protocol goroutine that // it needs to wake up and handle new segments queued to it. newSegmentWaker sleep.Waker `state:"manual"` @@ -607,33 +549,26 @@ type endpoint struct { // listener. deferAccept time.Duration - // pendingAccepted is a synchronization primitive used to track number - // of connections that are queued up to be delivered to the accepted - // channel. We use this to ensure that all goroutines blocked on writing - // to the acceptedChan below terminate before we close acceptedChan. + // pendingAccepted tracks connections queued to be accepted. It is used to + // ensure such queued connections are terminated before the accepted queue is + // marked closed (by setting its capacity to zero). 
pendingAccepted sync.WaitGroup `state:"nosave"` - // acceptMu protects acceptedChan. + // acceptMu protects accepted. acceptMu sync.Mutex `state:"nosave"` // acceptCond is a condition variable that can be used to block on when - // acceptedChan is full and an endpoint is ready to be delivered. - // - // This condition variable is required because just blocking on sending - // to acceptedChan does not work in cases where endpoint.Listen is - // called twice with different backlog values. In such cases the channel - // is closed and a new one created. Any pending goroutines blocking on - // the write to the channel will panic. + // accepted is full and an endpoint is ready to be delivered. // // We use this condition variable to block/unblock goroutines which // tried to deliver an endpoint but couldn't because accept backlog was // full ( See: endpoint.deliverAccepted ). acceptCond *sync.Cond `state:"nosave"` - // acceptedChan is used by a listening endpoint protocol goroutine to + // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. - acceptedChan chan *endpoint `state:".([]*endpoint)"` + accepted accepted // The following are only used from the protocol goroutine, and // therefore don't need locks to protect them. @@ -664,7 +599,7 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - gso *stack.GSO + gso stack.GSO // TODO(b/142022063): Add ability to save and restore per endpoint stats. 
stats Stats `state:"nosave"` @@ -779,7 +714,7 @@ func (e *endpoint) UnlockUser() { switch e.EndpointState() { case StateEstablished: - if err := e.handleSegments(true /* fastPath */); err != nil { + if err := e.handleSegmentsLocked(true /* fastPath */); err != nil { e.notifyProtocolGoroutine(notifyTickleWorker) } default: @@ -839,13 +774,13 @@ func (e *endpoint) EndpointState() EndpointState { // setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { - e.recentTS = recentTS + e.RecentTS = recentTS e.recentTSTime = time.Now() } // recentTimestamp returns the value of the recentTS field. func (e *endpoint) recentTimestamp() uint32 { - return e.recentTS + return e.RecentTS } // keepalive is a synchronization wrapper used to appease stateify. See the @@ -865,16 +800,17 @@ type keepalive struct { func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ stack: s, - EndpointInfo: EndpointInfo{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: header.TCPProtocolNumber, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + sndQueueInfo: sndQueueInfo{ + TCPSndBufState: stack.TCPSndBufState{ + SndMTU: int(math.MaxInt32), }, }, waiterQueue: waiterQueue, state: StateInitial, - rcvBufSize: DefaultReceiveBufferSize, - sndMTU: int(math.MaxInt32), keepalive: keepalive{ // Linux defaults. 
idle: 2 * time.Hour, @@ -886,10 +822,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) + e.ops.SetReceiveBufferSize(DefaultReceiveBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -898,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - e.rcvBufSize = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } var cs tcpip.CongestionControlOption @@ -908,7 +845,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { - e.rcvAutoParams.disabled = !bool(mrb) + e.rcvQueueInfo.RcvAutoParams.Disabled = !bool(mrb) } var de tcpip.TCPDelayEnabled @@ -933,7 +870,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.tsOffset = timeStampOffset() + e.TSOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) e.keepalive.timer.init(&e.keepalive.waker) @@ -959,10 +896,10 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result = mask case StateListen: - // Check if there's anything in the accepted channel. + // Check if there's anything in the accepted queue. 
if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() - if len(e.acceptedChan) > 0 { + if e.accepted.endpoints.Len() != 0 { result |= waiter.ReadableEvents } e.acceptMu.Unlock() @@ -971,21 +908,21 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.WritableEvents) != 0 { - e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Lock() sndBufSize := e.getSendBufferSize() - if e.sndClosed || e.sndBufUsed < sndBufSize { + if e.sndQueueInfo.SndClosed || e.sndQueueInfo.SndBufUsed < sndBufSize { result |= waiter.WritableEvents } - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() } // Determine if the endpoint is readable if requested. if (mask & waiter.ReadableEvents) != 0 { - e.rcvListMu.Lock() - if e.rcvBufUsed > 0 || e.rcvClosed { + e.rcvQueueInfo.rcvQueueMu.Lock() + if e.rcvQueueInfo.RcvBufUsed > 0 || e.rcvQueueInfo.RcvClosed { result |= waiter.ReadableEvents } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() } } @@ -1093,15 +1030,15 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: e.boundDest, @@ -1145,22 +1082,22 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. 
func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - if e.acceptedChan == nil { - e.acceptMu.Unlock() + acceptedCopy := e.accepted + e.accepted = accepted{} + e.acceptMu.Unlock() + + if acceptedCopy == (accepted{}) { return } - close(e.acceptedChan) - ch := e.acceptedChan - e.acceptedChan = nil + e.acceptCond.Broadcast() - e.acceptMu.Unlock() // Reset all connections that are waiting to be accepted. - for n := range ch { - n.notifyProtocolGoroutine(notifyReset) + for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { + n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) } // Wait for reset of all endpoints that are still waiting to be delivered to - // the now closed acceptedChan. + // the now closed accepted. e.pendingAccepted.Wait() } @@ -1176,7 +1113,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1184,8 +1121,8 @@ func (e *endpoint) cleanupLocked() { portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: e.boundDest, @@ -1247,19 +1184,19 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - if e.rcvAutoParams.disabled { - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + if e.rcvQueueInfo.RcvAutoParams.Disabled { + e.rcvQueueInfo.rcvQueueMu.Unlock() return } now := time.Now() - if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { - 
e.rcvAutoParams.copied += copied - e.rcvListMu.Unlock() + if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt { + e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied + e.rcvQueueInfo.rcvQueueMu.Unlock() return } - prevRTTCopied := e.rcvAutoParams.copied + copied - prevCopied := e.rcvAutoParams.prevCopied + prevRTTCopied := e.rcvQueueInfo.RcvAutoParams.CopiedBytes + copied + prevCopied := e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes rcvWnd := 0 if prevRTTCopied > prevCopied { // The minimal receive window based on what was copied by the app @@ -1291,24 +1228,25 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // We do not adjust downwards as that can cause the receiver to // reject valid data that might already be in flight as the // acceptable window will shrink. - if rcvWnd > e.rcvBufSize { - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvBufSize = rcvWnd - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { + rcvBufSize := int(e.ops.GetReceiveBufferSize()) + if rcvWnd > rcvBufSize { + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + e.ops.SetReceiveBufferSize(int64(rcvWnd), false /* notify */) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(rcvWnd)) + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, rcvBufSize); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } - // We only update prevCopied when we grow the buffer because in cases - // where prevCopied > prevRTTCopied the existing buffer is already big + // We only update PrevCopiedBytes when we grow the buffer because in cases + // where PrevCopiedBytes > prevRTTCopied the existing buffer is already big // enough to handle the current rate and we don't need to do any // adjustments. 
- e.rcvAutoParams.prevCopied = prevRTTCopied + e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = prevRTTCopied } - e.rcvAutoParams.measureTime = now - e.rcvAutoParams.copied = 0 - e.rcvListMu.Unlock() + e.rcvQueueInfo.RcvAutoParams.MeasureTime = now + e.rcvQueueInfo.RcvAutoParams.CopiedBytes = 0 + e.rcvQueueInfo.rcvQueueMu.Unlock() } // SetOwner implements tcpip.Endpoint.SetOwner. @@ -1342,6 +1280,12 @@ func (e *endpoint) LastError() tcpip.Error { return e.lastErrorLocked() } +// LastErrorLocked reads and clears lastError with e.mu held. +// Only to be used in tests. +func (e *endpoint) LastErrorLocked() tcpip.Error { + return e.lastErrorLocked() +} + // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. func (e *endpoint) UpdateLastError(err tcpip.Error) { e.LockUser() @@ -1357,7 +1301,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult defer e.rcvReadMu.Unlock() // N.B. Here we get a range of segments to be processed. It is safe to not - // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we + // hold rcvQueueMu when processing, since we hold rcvReadMu to ensure only we // can remove segments from the list through commitRead(). first, last, serr := e.startRead() if serr != nil { @@ -1429,10 +1373,10 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. 
- e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() - bufUsed := e.rcvBufUsed + bufUsed := e.rcvQueueInfo.RcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { if s == StateError { if err := e.hardErrorLocked(); err != nil { @@ -1444,14 +1388,14 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { return nil, nil, &tcpip.ErrNotConnected{} } - if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.EndpointState().connected() { + if e.rcvQueueInfo.RcvBufUsed == 0 { + if e.rcvQueueInfo.RcvClosed || !e.EndpointState().connected() { return nil, nil, &tcpip.ErrClosedForReceive{} } return nil, nil, &tcpip.ErrWouldBlock{} } - return e.rcvList.Front(), e.rcvList.Back(), nil + return e.rcvQueueInfo.rcvQueue.Front(), e.rcvQueueInfo.rcvQueue.Back(), nil } // commitRead commits a read of done bytes and returns the next non-empty @@ -1467,39 +1411,39 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { func (e *endpoint) commitRead(done int) *segment { e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() memDelta := 0 - s := e.rcvList.Front() + s := e.rcvQueueInfo.rcvQueue.Front() for s != nil && s.data.Size() == 0 { - e.rcvList.Remove(s) + e.rcvQueueInfo.rcvQueue.Remove(s) // Memory is only considered released when the whole segment has been // read. memDelta += s.segMemSize() s.decRef() - s = e.rcvList.Front() + s = e.rcvQueueInfo.rcvQueue.Front() } - e.rcvBufUsed -= done + e.rcvQueueInfo.RcvBufUsed -= done if memDelta > 0 { // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. 
- if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } - return e.rcvList.Front() + return e.rcvQueueInfo.rcvQueue.Front() } // isEndpointWritableLocked checks if a given endpoint is writable // and also returns the number of bytes that can be written at this // moment. If the endpoint is not writable then it returns an error // indicating the reason why it's not writable. -// Caller must hold e.mu and e.sndBufMu +// Caller must hold e.mu and e.sndQueueMu func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { @@ -1519,12 +1463,12 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { } // Check if the connection has already been closed for sends. - if e.sndClosed { + if e.sndQueueInfo.SndClosed { return 0, &tcpip.ErrClosedForSend{} } sndBufSize := e.getSendBufferSize() - avail := sndBufSize - e.sndBufUsed + avail := sndBufSize - e.sndQueueInfo.SndBufUsed if avail <= 0 { return 0, &tcpip.ErrWouldBlock{} } @@ -1541,8 +1485,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp defer e.UnlockUser() nextSeg, n, err := func() (*segment, int, tcpip.Error) { - e.sndBufMu.Lock() - defer e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Lock() + defer e.sndQueueInfo.sndQueueMu.Unlock() avail, err := e.isEndpointWritableLocked() if err != nil { @@ -1557,8 +1501,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // available buffer space to be consumed by some other caller while we // are copying data in. 
if !opts.Atomic { - e.sndBufMu.Unlock() - defer e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Unlock() + defer e.sndQueueInfo.sndQueueMu.Lock() e.UnlockUser() defer e.LockUser() @@ -1600,10 +1544,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, v) + e.sndQueueInfo.SndBufUsed += len(v) + e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) + e.sndQueueInfo.sndQueue.PushBack(s) return e.drainSendQueueLocked(), len(v), nil }() @@ -1618,11 +1562,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // selectWindowLocked returns the new window without checking for shrinking or scaling // applied. -// Precondition: e.mu and e.rcvListMu must be held. -func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { - wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked()) - maxWindow := wndFromSpace(e.rcvBufSize) - wndFromUsedBytes := maxWindow - e.rcvBufUsed +// Precondition: e.mu and e.rcvQueueMu must be held. +func (e *endpoint) selectWindowLocked(rcvBufSize int) (wnd seqnum.Size) { + wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + maxWindow := wndFromSpace(rcvBufSize) + wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in // cases where we receive a lot of small segments the segment overhead is a @@ -1640,11 +1584,11 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { return seqnum.Size(newWnd) } -// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu. +// selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu. 
func (e *endpoint) selectWindow() (wnd seqnum.Size) { - e.rcvListMu.Lock() - wnd = e.selectWindowLocked() - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + wnd = e.selectWindowLocked(int(e.ops.GetReceiveBufferSize())) + e.rcvQueueInfo.rcvQueueMu.Unlock() return wnd } @@ -1662,9 +1606,9 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) { // above will be true if the new window is >= ACK threshold and false // otherwise. // -// Precondition: e.mu and e.rcvListMu must be held. -func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := int(e.selectWindowLocked()) +// Precondition: e.mu and e.rcvQueueMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int, rcvBufSize int) (crossed bool, above bool) { + newAvail := int(e.selectWindowLocked(rcvBufSize)) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 @@ -1673,7 +1617,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // rcvBufFraction is the inverse of the fraction of receive buffer size that // is used to decide if the available buffer space is now above it. const rcvBufFraction = 2 - if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + if wndThreshold := wndFromSpace(rcvBufSize / rcvBufFraction); threshold > wndThreshold { threshold = wndThreshold } switch { @@ -1700,7 +1644,7 @@ func (e *endpoint) OnReusePortSet(v bool) { } // OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet. -func (e *endpoint) OnKeepAliveSet(v bool) { +func (e *endpoint) OnKeepAliveSet(bool) { e.notifyProtocolGoroutine(notifyKeepaliveChanged) } @@ -1708,7 +1652,7 @@ func (e *endpoint) OnKeepAliveSet(v bool) { func (e *endpoint) OnDelayOptionSet(v bool) { if !v { // Handle delayed data. 
- e.sndWaker.Assert() + e.sndQueueInfo.sndWaker.Assert() } } @@ -1716,7 +1660,7 @@ func (e *endpoint) OnDelayOptionSet(v bool) { func (e *endpoint) OnCorkOptionSet(v bool) { if !v { // Handle the corked data. - e.sndWaker.Assert() + e.sndQueueInfo.sndWaker.Assert() } } @@ -1724,6 +1668,37 @@ func (e *endpoint) getSendBufferSize() int { return int(e.ops.GetSendBufferSize()) } +// OnSetReceiveBufferSize implements tcpip.SocketOptionsHandler.OnSetReceiveBufferSize. +func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) { + e.LockUser() + e.rcvQueueInfo.rcvQueueMu.Lock() + + // Make sure the receive buffer size allows us to send a + // non-zero window size. + scale := uint8(0) + if e.rcv != nil { + scale = e.rcv.RcvWndScale + } + if rcvBufSz>>scale == 0 { + rcvBufSz = 1 << scale + } + + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(int(oldSz))) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(int(rcvBufSz))) + e.rcvQueueInfo.RcvAutoParams.Disabled = true + + // Immediately send an ACK to uncork the sender silly window + // syndrome prevetion, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, int(rcvBufSz)); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + e.rcvQueueInfo.rcvQueueMu.Unlock() + e.UnlockUser() + return rcvBufSz +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 @@ -1767,56 +1742,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return &tcpip.ErrNotSupported{} } - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. 
- var rs tcpip.TCPReceiveBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) - } - - if v > rs.Max { - v = rs.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < rs.Min { - v = rs.Min - } - } else { - v = math.MaxInt32 - } - - e.LockUser() - e.rcvListMu.Lock() - - // Make sure the receive buffer size allows us to send a - // non-zero window size. - scale := uint8(0) - if e.rcv != nil { - scale = e.rcv.rcvWndScale - } - if v>>scale == 0 { - v = 1 << scale - } - - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvBufSize = v - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - - e.rcvAutoParams.disabled = true - - // Immediately send an ACK to uncork the sender silly window - // syndrome prevetion, when our available space grows above aMSS - // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) - } - - e.rcvListMu.Unlock() - e.UnlockUser() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1959,10 +1884,10 @@ func (e *endpoint) readyReceiveSize() (int, tcpip.Error) { return 0, &tcpip.ErrInvalidEndpointState{} } - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() - return e.rcvBufUsed, nil + return e.rcvQueueInfo.RcvBufUsed, nil } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. 
@@ -2002,12 +1927,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - v := e.rcvBufSize - e.rcvListMu.Unlock() - return v, nil - case tcpip.TTLOption: e.LockUser() v := int(e.ttl) @@ -2043,15 +1962,15 @@ func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { // the connection did not send and receive data, then RTT will // be zero. snd.rtt.Lock() - info.RTT = snd.rtt.srtt - info.RTTVar = snd.rtt.rttvar + info.RTT = snd.rtt.TCPRTTState.SRTT + info.RTTVar = snd.rtt.TCPRTTState.RTTVar snd.rtt.Unlock() - info.RTO = snd.rto + info.RTO = snd.RTO info.CcState = snd.state - info.SndSsthresh = uint32(snd.sndSsthresh) - info.SndCwnd = uint32(snd.sndCwnd) - info.ReorderSeen = snd.rc.reorderSeen + info.SndSsthresh = uint32(snd.Ssthresh) + info.SndCwnd = uint32(snd.SndCwnd) + info.ReorderSeen = snd.rc.Reord } e.UnlockUser() return info @@ -2096,7 +2015,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) e.UnlockUser() if err != nil { return err @@ -2204,20 +2123,20 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp } // Find a route to the desired destination. 
- r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.TransportEndpointInfo.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.ID.LocalAddress = r.LocalAddress() - e.ID.RemoteAddress = r.RemoteAddress() - e.ID.RemotePort = addr.Port + e.TransportEndpointInfo.ID.LocalAddress = r.LocalAddress() + e.TransportEndpointInfo.ID.RemoteAddress = r.RemoteAddress() + e.TransportEndpointInfo.ID.RemotePort = addr.Port - if e.ID.LocalPort != 0 { + if e.TransportEndpointInfo.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -2226,19 +2145,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + sameAddr := e.TransportEndpointInfo.ID.LocalAddress == e.TransportEndpointInfo.ID.RemoteAddress // Calculate a port offset based on the destination IP/port and // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. 
- h := jenkins.Sum32(e.stack.Seed()) - h.Write([]byte(e.ID.LocalAddress)) - h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) - h.Write(portBuf) - portOffset := uint16(h.Sum32()) + + h := jenkins.Sum32(e.stack.Seed()) + for _, s := range [][]byte{ + []byte(e.ID.LocalAddress), + []byte(e.ID.RemoteAddress), + portBuf, + } { + // Per io.Writer.Write: + // + // Write must return a non-nil error if it returns n < len(p). + if _, err := h.Write(s); err != nil { + panic(err) + } + } + portOffset := h.Sum32() var twReuse tcpip.TCPTimeWaitReuseOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { @@ -2249,21 +2178,21 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { switch netProto { case header.IPv4ProtocolNumber: - reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + reuse = header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.LocalAddress) && header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.RemoteAddress) case header.IPv6ProtocolNumber: - reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + reuse = e.TransportEndpointInfo.ID.LocalAddress == header.IPv6Loopback && e.TransportEndpointInfo.ID.RemoteAddress == header.IPv6Loopback } } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) { - if sameAddr && p == e.ID.RemotePort { + if sameAddr && p == e.TransportEndpointInfo.ID.RemotePort { return false, nil } portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2273,7 +2202,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake 
bool, run bool) tcp if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } - transEPID := e.ID + transEPID := e.TransportEndpointInfo.ID transEPID.LocalPort = p // Check if an endpoint is registered with demuxer in TIME-WAIT and if // we can reuse it. If we can't find a transport endpoint then we just @@ -2310,7 +2239,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2321,13 +2250,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp } } - id := e.ID + id := e.TransportEndpointInfo.ID id.LocalPort = p if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2342,13 +2271,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // Port picking successful. Save the details of // the selected port. - e.ID = id + e.TransportEndpointInfo.ID = id e.isPortReserved = true e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags e.boundDest = addr return true, nil }); err != nil { + e.stack.Stats().TCP.FailedPortReservations.Increment() return err } } @@ -2367,10 +2297,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // connection setting here. 
if !handshake { e.segmentQueue.mu.Lock() - for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { + for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.ID - e.sndWaker.Assert() + s.id = e.TransportEndpointInfo.ID + e.sndQueueInfo.sndWaker.Assert() } } e.segmentQueue.mu.Unlock() @@ -2412,10 +2342,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Close for read. if e.shutdownFlags&tcpip.ShutdownRead != 0 { // Mark read side as closed. - e.rcvListMu.Lock() - e.rcvClosed = true - rcvBufUsed := e.rcvBufUsed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = true + rcvBufUsed := e.rcvQueueInfo.RcvBufUsed + e.rcvQueueInfo.rcvQueueMu.Unlock() // If we're fully closed and we have unread data we need to abort // the connection with a RST. @@ -2429,10 +2359,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Close for write. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - e.sndBufMu.Lock() - if e.sndClosed { + e.sndQueueInfo.sndQueueMu.Lock() + if e.sndQueueInfo.SndClosed { // Already closed. - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() if e.EndpointState() == StateTimeWait { return &tcpip.ErrNotConnected{} } @@ -2440,12 +2370,12 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { } // Queue fin segment. - s := newOutgoingSegment(e.ID, nil) - e.sndQueue.PushBack(s) - e.sndBufInQueue++ + s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil) + e.sndQueueInfo.sndQueue.PushBack(s) + e.sndQueueInfo.SndBufInQueue++ // Mark endpoint as closed. 
- e.sndClosed = true - e.sndBufMu.Unlock() + e.sndQueueInfo.SndClosed = true + e.sndQueueInfo.sndQueueMu.Unlock() e.handleClose() } @@ -2458,9 +2388,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // // By not removing this endpoint from the demuxer mapping, we // ensure that any other bind to the same port fails, as on Linux. - e.rcvListMu.Lock() - e.rcvClosed = true - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = true + e.rcvQueueInfo.rcvQueueMu.Unlock() e.closePendingAcceptableConnectionsLocked() // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.ReadableEvents | waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) @@ -2491,28 +2421,20 @@ func (e *endpoint) listen(backlog int) tcpip.Error { if e.EndpointState() == StateListen && !e.closed { e.acceptMu.Lock() defer e.acceptMu.Unlock() - if e.acceptedChan == nil { + if e.accepted == (accepted{}) { // listen is called after shutdown. - e.acceptedChan = make(chan *endpoint, backlog) + e.accepted.cap = backlog e.shutdownFlags = 0 - e.rcvListMu.Lock() - e.rcvClosed = false - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = false + e.rcvQueueInfo.rcvQueueMu.Unlock() } else { - // Adjust the size of the channel iff we can fix + // Adjust the size of the backlog iff we can fit // existing pending connections into the new one. - if len(e.acceptedChan) > backlog { + if e.accepted.endpoints.Len() > backlog { return &tcpip.ErrInvalidEndpointState{} } - if cap(e.acceptedChan) == backlog { - return nil - } - origChan := e.acceptedChan - e.acceptedChan = make(chan *endpoint, backlog) - close(origChan) - for ep := range origChan { - e.acceptedChan <- ep - } + e.accepted.cap = backlog } // Notify any blocked goroutines that they can attempt to @@ -2538,19 +2460,19 @@ func (e *endpoint) listen(backlog int) tcpip.Error { } // Register the endpoint. 
- if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) - // The channel may be non-nil when we're restoring the endpoint, and it + // The queue may be non-zero when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() - if e.acceptedChan == nil { - e.acceptedChan = make(chan *endpoint, backlog) + if e.accepted == (accepted{}) { + e.accepted.cap = backlog } e.acceptMu.Unlock() @@ -2578,24 +2500,25 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - rcvClosed := e.rcvClosed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := e.rcvQueueInfo.RcvClosed + e.rcvQueueInfo.rcvQueueMu.Unlock() // Endpoint must be in listen state before it can accept connections. if rcvClosed || e.EndpointState() != StateListen { return nil, nil, &tcpip.ErrInvalidEndpointState{} } // Get the new accepted endpoint. 
- e.acceptMu.Lock() - defer e.acceptMu.Unlock() var n *endpoint - select { - case n = <-e.acceptedChan: - e.acceptCond.Signal() - default: + e.acceptMu.Lock() + if element := e.accepted.endpoints.Front(); element != nil { + n = e.accepted.endpoints.Remove(element).(*endpoint) + } + e.acceptMu.Unlock() + if n == nil { return nil, nil, &tcpip.ErrWouldBlock{} } + e.acceptCond.Signal() if peerAddr != nil { *peerAddr = n.getRemoteAddress() } @@ -2645,7 +2568,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { if nic == 0 { return &tcpip.ErrBadLocalAddress{} } - e.ID.LocalAddress = addr.Addr + e.TransportEndpointInfo.ID.LocalAddress = addr.Addr } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) @@ -2659,7 +2582,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { Dest: tcpip.FullAddress{}, } port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { - id := e.ID + id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a // listening endpoint bound with the same id and portFlags and bindToDevice @@ -2675,6 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { return true, nil }) if err != nil { + e.stack.Stats().TCP.FailedPortReservations.Increment() return err } @@ -2684,7 +2608,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { e.boundNICID = nic e.isPortReserved = true e.effectiveNetProtos = netProtos - e.ID.LocalPort = port + e.TransportEndpointInfo.ID.LocalPort = port // Mark endpoint as bound. 
e.setEndpointState(StateBound) @@ -2698,8 +2622,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { defer e.UnlockUser() return tcpip.FullAddress{ - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -2718,8 +2642,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, NIC: e.boundNICID, } } @@ -2758,13 +2682,13 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p Payload: pkt.Data().AsRange().ToOwnedView(), Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -2777,12 +2701,12 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p // HandleError implements stack.TransportEndpoint. 
func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { handlePacketTooBig := func(mtu uint32) { - e.sndBufMu.Lock() - e.packetTooBigCount++ - if v := int(mtu); v < e.sndMTU { - e.sndMTU = v + e.sndQueueInfo.sndQueueMu.Lock() + e.sndQueueInfo.PacketTooBigCount++ + if v := int(mtu); v < e.sndQueueInfo.SndMTU { + e.sndQueueInfo.SndMTU = v } - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) } @@ -2801,14 +2725,14 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { sendBufferSize := e.getSendBufferSize() - e.sndBufMu.Lock() - notify := e.sndBufUsed >= sendBufferSize>>1 - e.sndBufUsed -= v + e.sndQueueInfo.sndQueueMu.Lock() + notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1 + e.sndQueueInfo.SndBufUsed -= v // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 - e.sndBufMu.Unlock() + notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + e.sndQueueInfo.sndQueueMu.Unlock() if notify { e.waiterQueue.Notify(waiter.WritableEvents) @@ -2819,58 +2743,50 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). 
func (e *endpoint) readyToRead(s *segment) { - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() if s != nil { - e.rcvBufUsed += s.payloadSize() + e.rcvQueueInfo.RcvBufUsed += s.payloadSize() s.incRef() - e.rcvList.PushBack(s) + e.rcvQueueInfo.rcvQueue.PushBack(s) } else { - e.rcvClosed = true + e.rcvQueueInfo.RcvClosed = true } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() e.waiterQueue.Notify(waiter.ReadableEvents) } // receiveBufferAvailableLocked calculates how many bytes are still available // in the receive buffer. -// rcvListMu must be held when this function is called. -func (e *endpoint) receiveBufferAvailableLocked() int { +// rcvQueueMu must be held when this function is called. +func (e *endpoint) receiveBufferAvailableLocked(rcvBufSize int) int { // We may use more bytes than the buffer size when the receive buffer // shrinks. memUsed := e.receiveMemUsed() - if memUsed >= e.rcvBufSize { + if memUsed >= rcvBufSize { return 0 } - return e.rcvBufSize - memUsed + return rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the // receive buffer based on the actual memory used by all segments held in // receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { - e.rcvListMu.Lock() - available := e.receiveBufferAvailableLocked() - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + available := e.receiveBufferAvailableLocked(int(e.ops.GetReceiveBufferSize())) + e.rcvQueueInfo.rcvQueueMu.Unlock() return available } // receiveBufferUsed returns the amount of in-use receive buffer. func (e *endpoint) receiveBufferUsed() int { - e.rcvListMu.Lock() - used := e.rcvBufUsed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + used := e.rcvQueueInfo.RcvBufUsed + e.rcvQueueInfo.rcvQueueMu.Unlock() return used } -// receiveBufferSize returns the current size of the receive buffer. 
-func (e *endpoint) receiveBufferSize() int { - e.rcvListMu.Lock() - size := e.rcvBufSize - e.rcvListMu.Unlock() - return size -} - // receiveMemUsed returns the total memory in use by segments held by this // endpoint. func (e *endpoint) receiveMemUsed() int { @@ -2899,11 +2815,11 @@ func (e *endpoint) maxReceiveBufferSize() int { // receiveBuffer otherwise we use the max permissible receive buffer size to // compute the scale. func (e *endpoint) rcvWndScaleForHandshake() int { - bufSizeForScale := e.receiveBufferSize() + bufSizeForScale := e.ops.GetReceiveBufferSize() - e.rcvListMu.Lock() - autoTuningDisabled := e.rcvAutoParams.disabled - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled + e.rcvQueueInfo.rcvQueueMu.Unlock() if autoTuningDisabled { return FindWndScale(seqnum.Size(bufSizeForScale)) } @@ -2914,7 +2830,7 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + if e.SendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { e.setRecentTimestamp(tsVal) } } @@ -2924,7 +2840,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, // initializes the recentTS with the value provided in synOpts.TSval. 
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { - e.sendTSOk = true + e.SendTSOk = true e.setRecentTimestamp(synOpts.TSVal) } } @@ -2932,7 +2848,7 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(time.Now(), e.tsOffset) + return tcpTimeStamp(time.Now(), e.TSOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is @@ -2971,7 +2887,7 @@ func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { return } if bool(v) && synOpts.SACKPermitted { - e.sackPermitted = true + e.SACKPermitted = true } } @@ -2985,144 +2901,70 @@ func (e *endpoint) maxOptionSize() (size int) { return size } -// completeState makes a full copy of the endpoint and returns it. This is used -// before invoking the probe. The state returned may not be fully consistent if -// there are intervening syscalls when the state is being copied. -func (e *endpoint) completeState() stack.TCPEndpointState { - var s stack.TCPEndpointState - s.SegTime = time.Now() - - // Copy EndpointID. - s.ID = stack.TCPEndpointID(e.ID) - - // Copy endpoint rcv state. - e.rcvListMu.Lock() - s.RcvBufSize = e.rcvBufSize - s.RcvBufUsed = e.rcvBufUsed - s.RcvClosed = e.rcvClosed - s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime - s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied - s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied - s.RcvAutoParams.RTT = e.rcvAutoParams.rtt - s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber - s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime - s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled - e.rcvListMu.Unlock() - - // Endpoint TCP Option state. 
- s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTimestamp() - s.TSOffset = e.tsOffset - s.SACKPermitted = e.sackPermitted +// completeStateLocked makes a full copy of the endpoint and returns it. This is +// used before invoking the probe. +// +// Precondition: e.mu must be held. +func (e *endpoint) completeStateLocked() stack.TCPEndpointState { + s := stack.TCPEndpointState{ + TCPEndpointStateInner: e.TCPEndpointStateInner, + ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID), + SegTime: time.Now(), + Receiver: e.rcv.TCPReceiverState, + Sender: e.snd.TCPSenderState, + } + + sndBufSize := e.getSendBufferSize() + // Copy the send buffer atomically. + e.sndQueueInfo.sndQueueMu.Lock() + s.SndBufState = e.sndQueueInfo.TCPSndBufState + s.SndBufState.SndBufSize = sndBufSize + e.sndQueueInfo.sndQueueMu.Unlock() + + // Copy the receive buffer atomically. + e.rcvQueueInfo.rcvQueueMu.Lock() + s.RcvBufState = e.rcvQueueInfo.TCPRcvBufState + e.rcvQueueInfo.rcvQueueMu.Unlock() + + // Copy the endpoint TCP Option state. s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks]) s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() - // Copy endpoint send state. - sndBufSize := e.getSendBufferSize() - e.sndBufMu.Lock() - s.SndBufSize = sndBufSize - s.SndBufUsed = e.sndBufUsed - s.SndClosed = e.sndClosed - s.SndBufInQueue = e.sndBufInQueue - s.PacketTooBigCount = e.packetTooBigCount - s.SndMTU = e.sndMTU - e.sndBufMu.Unlock() - - // Copy receiver state. - s.Receiver = stack.TCPReceiverState{ - RcvNxt: e.rcv.rcvNxt, - RcvAcc: e.rcv.rcvAcc, - RcvWndScale: e.rcv.rcvWndScale, - PendingBufUsed: e.rcv.pendingBufUsed, - } - - // Copy sender state. 
- s.Sender = stack.TCPSenderState{ - LastSendTime: e.snd.lastSendTime, - DupAckCount: e.snd.dupAckCount, - FastRecovery: stack.TCPFastRecoveryState{ - Active: e.snd.fr.active, - First: e.snd.fr.first, - Last: e.snd.fr.last, - MaxCwnd: e.snd.fr.maxCwnd, - HighRxt: e.snd.fr.highRxt, - RescueRxt: e.snd.fr.rescueRxt, - }, - SndCwnd: e.snd.sndCwnd, - Ssthresh: e.snd.sndSsthresh, - SndCAAckCount: e.snd.sndCAAckCount, - Outstanding: e.snd.outstanding, - SackedOut: e.snd.sackedOut, - SndWnd: e.snd.sndWnd, - SndUna: e.snd.sndUna, - SndNxt: e.snd.sndNxt, - RTTMeasureSeqNum: e.snd.rttMeasureSeqNum, - RTTMeasureTime: e.snd.rttMeasureTime, - Closed: e.snd.closed, - RTO: e.snd.rto, - MaxPayloadSize: e.snd.maxPayloadSize, - SndWndScale: e.snd.sndWndScale, - MaxSentAck: e.snd.maxSentAck, - } e.snd.rtt.Lock() - s.Sender.SRTT = e.snd.rtt.srtt - s.Sender.SRTTInited = e.snd.rtt.srttInited + s.Sender.RTTState = e.snd.rtt.TCPRTTState e.snd.rtt.Unlock() if cubic, ok := e.snd.cc.(*cubicState); ok { - s.Sender.Cubic = stack.TCPCubicState{ - WMax: cubic.wMax, - WLastMax: cubic.wLastMax, - T: cubic.t, - TimeSinceLastCongestion: time.Since(cubic.t), - C: cubic.c, - K: cubic.k, - Beta: cubic.beta, - WC: cubic.wC, - WEst: cubic.wEst, - } + s.Sender.Cubic = cubic.TCPCubicState + s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T) } - rc := &e.snd.rc - s.Sender.RACKState = stack.TCPRACKState{ - XmitTime: rc.xmitTime, - EndSequence: rc.endSequence, - FACK: rc.fack, - RTT: rc.rtt, - Reord: rc.reorderSeen, - DSACKSeen: rc.dsackSeen, - ReoWnd: rc.reoWnd, - ReoWndIncr: rc.reoWndIncr, - ReoWndPersist: rc.reoWndPersist, - RTTSeq: rc.rttSeq, - } + s.Sender.RACKState = e.snd.rc.TCPRACKState return s } func (e *endpoint) initHardwareGSO() { - gso := &stack.GSO{} switch e.route.NetProto() { case header.IPv4ProtocolNumber: - gso.Type = stack.GSOTCPv4 - gso.L3HdrLen = header.IPv4MinimumSize + e.gso.Type = stack.GSOTCPv4 + e.gso.L3HdrLen = header.IPv4MinimumSize case 
header.IPv6ProtocolNumber: - gso.Type = stack.GSOTCPv6 - gso.L3HdrLen = header.IPv6MinimumSize + e.gso.Type = stack.GSOTCPv6 + e.gso.L3HdrLen = header.IPv6MinimumSize default: panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) } - gso.NeedsCsum = true - gso.CsumOffset = header.TCPChecksumOffset - gso.MaxSize = e.route.GSOMaxSize() - e.gso = gso + e.gso.NeedsCsum = true + e.gso.CsumOffset = header.TCPChecksumOffset + e.gso.MaxSize = e.route.GSOMaxSize() } func (e *endpoint) initGSO() { if e.route.HasHardwareGSOCapability() { e.initHardwareGSO() } else if e.route.HasSoftwareGSOCapability() { - e.gso = &stack.GSO{ + e.gso = stack.GSO{ MaxSize: e.route.GSOMaxSize(), Type: stack.GSOSW, NeedsCsum: false, @@ -3200,3 +3042,17 @@ func (e *endpoint) allowOutOfWindowAck() bool { e.lastOutOfWindowAckTime = now return true } + +// GetTCPReceiveBufferLimits is used to get send buffer size limits for TCP. +func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOption { + var ss tcpip.TCPReceiveBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.ReceiveBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index a53d76917..6e9777fe4 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -58,7 +58,7 @@ func (e *endpoint) beforeSave() { if !e.route.HasSaveRestoreCapability() { if !e.route.HasDisconncetOkCapability() { panic(&tcpip.ErrSaveRejection{ - Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), + Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", 
e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort), }) } e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) @@ -67,7 +67,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() } if !e.workerRunning { - // The endpoint must be in acceptedChan or has been just + // The endpoint must be in the accepted queue or has been just // disconnected and closed. break } @@ -88,7 +88,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() } if e.workerRunning { - panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID)) } default: panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) @@ -99,37 +99,19 @@ func (e *endpoint) beforeSave() { } } -// saveAcceptedChan is invoked by stateify. -func (e *endpoint) saveAcceptedChan() []*endpoint { - if e.acceptedChan == nil { - return nil - } - acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan)) - for i := 0; i < len(acceptedEndpoints); i++ { - select { - case ep := <-e.acceptedChan: - acceptedEndpoints[i] = ep - default: - panic("endpoint acceptedChan buffer got consumed by background context") - } - } - for i := 0; i < len(acceptedEndpoints); i++ { - select { - case e.acceptedChan <- acceptedEndpoints[i]: - default: - panic("endpoint acceptedChan buffer got populated by background context") - } +// saveEndpoints is invoked by stateify. +func (a *accepted) saveEndpoints() []*endpoint { + acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) + for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { + acceptedEndpoints[i] = e.Value.(*endpoint) } return acceptedEndpoints } -// loadAcceptedChan is invoked by stateify. 
-func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { - if cap(acceptedEndpoints) > 0 { - e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints)) - for _, ep := range acceptedEndpoints { - e.acceptedChan <- ep - } +// loadEndpoints is invoked by stateify. +func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { + for _, ep := range acceptedEndpoints { + a.endpoints.PushBack(ep) } } @@ -183,7 +165,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { @@ -198,14 +180,14 @@ func (e *endpoint) Resume(s *stack.Stack) { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) + if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) { + panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max)) } } } bind := func() { - addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}) if err != nil { panic("unable to parse BindAddr: " + err.String()) } @@ -231,19 +213,19 @@ func (e *endpoint) Resume(s *stack.Stack) { case epState.connected(): bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.ID.RemoteAddress + e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. 
If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning) + err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning) if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } @@ -263,7 +245,7 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() - backlog := cap(e.acceptedChan) + backlog := e.accepted.cap if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } @@ -281,7 +263,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}) + err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}) if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } @@ -328,23 +310,3 @@ func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { e.lastOutOfWindowAckTime = 
time.Unix(unix.second, unix.nano) } - -// saveMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.measureTime = time.Unix(unix.second, unix.nano) -} - -// saveRttMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime { - return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()} -} - -// loadRttMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) { - r.rttMeasureTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a4667906..a3d1aa1a3 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -75,63 +75,6 @@ const ( ccCubic = "cubic" ) -// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The -// value is protected by a mutex so that we can increment only when it's -// guaranteed not to go above a threshold. -type synRcvdCounter struct { - sync.Mutex - value uint64 - pending sync.WaitGroup - threshold uint64 -} - -// inc tries to increment the global number of endpoints in SYN-RCVD state. It -// succeeds if the increment doesn't make the count go beyond the threshold, and -// fails otherwise. -func (s *synRcvdCounter) inc() bool { - s.Lock() - defer s.Unlock() - if s.value >= s.threshold { - return false - } - - s.pending.Add(1) - s.value++ - - return true -} - -// dec atomically decrements the global number of endpoints in SYN-RCVD -// state. It must only be called if a previous call to inc succeeded. -func (s *synRcvdCounter) dec() { - s.Lock() - defer s.Unlock() - s.value-- - s.pending.Done() -} - -// synCookiesInUse returns true if the synRcvdCount is greater than -// SynRcvdCountThreshold. 
-func (s *synRcvdCounter) synCookiesInUse() bool { - s.Lock() - defer s.Unlock() - return s.value >= s.threshold -} - -// SetThreshold sets synRcvdCounter.Threshold to ths new threshold. -func (s *synRcvdCounter) SetThreshold(threshold uint64) { - s.Lock() - defer s.Unlock() - s.threshold = threshold -} - -// Threshold returns the current value of synRcvdCounter.Threhsold. -func (s *synRcvdCounter) Threshold() uint64 { - s.Lock() - defer s.Unlock() - return s.threshold -} - type protocol struct { stack *stack.Stack @@ -139,6 +82,7 @@ type protocol struct { sackEnabled bool recovery tcpip.TCPRecovery delayEnabled bool + alwaysUseSynCookies bool sendBufferSize tcpip.TCPSendBufferSizeRangeOption recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string @@ -150,7 +94,6 @@ type protocol struct { minRTO time.Duration maxRTO time.Duration maxRetries uint32 - synRcvdCount synRcvdCounter synRetries uint8 dispatcher dispatcher } @@ -216,8 +159,8 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, // replyWithReset replies to the given segment with a reset segment. // // If the passed TTL is 0, then the route's default TTL will be used. -func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { - route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) +func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { + route, err := st.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err } @@ -257,7 +200,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error seq: seq, ack: ack, rcvWnd: 0, - }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */) + }, buffer.VectorisedView{}, stack.GSO{}, nil /* PacketOwner */) } // SetOption implements stack.TransportProtocol.SetOption. 
@@ -373,9 +316,9 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip p.mu.Unlock() return nil - case *tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPAlwaysUseSynCookies: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(*v)) + p.alwaysUseSynCookies = bool(*v) p.mu.Unlock() return nil @@ -480,9 +423,9 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er p.mu.RUnlock() return nil - case *tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPAlwaysUseSynCookies: p.mu.RLock() - *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + *v = tcpip.TCPAlwaysUseSynCookies(p.alwaysUseSynCookies) p.mu.RUnlock() return nil @@ -507,12 +450,6 @@ func (p *protocol) Wait() { p.dispatcher.wait() } -// SynRcvdCounter returns a reference to the synRcvdCount for this protocol -// instance. -func (p *protocol) SynRcvdCounter() *synRcvdCounter { - return &p.synRcvdCount -} - // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { return parse.TCP(pkt) @@ -537,7 +474,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { lingerTimeout: DefaultTCPLingerTimeout, timeWaitTimeout: DefaultTCPTimeWaitTimeout, timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly, - synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 0a0d5f7a1..9e332dcf7 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -46,54 +47,16 @@ const ( // // +stateify savable type rackControl struct { - // dsackSeen indicates if the connection has seen a DSACK. 
- dsackSeen bool - - // endSequence is the ending TCP sequence number of the most recent - // acknowledged segment. - endSequence seqnum.Value + stack.TCPRACKState // exitedRecovery indicates if the connection is exiting loss recovery. // This flag is set if the sender is leaving the recovery after // receiving an ACK and is reset during updating of reorder window. exitedRecovery bool - // fack is the highest selectively or cumulatively acknowledged - // sequence. - fack seqnum.Value - // minRTT is the estimated minimum RTT of the connection. minRTT time.Duration - // reorderSeen indicates if reordering has been detected on this - // connection. - reorderSeen bool - - // reoWnd is the reordering window time used for recording packet - // transmission times. It is used to defer the moment at which RACK - // marks a packet lost. - reoWnd time.Duration - - // reoWndIncr is the multiplier applied to adjust reorder window. - reoWndIncr uint8 - - // reoWndPersist is the number of loss recoveries before resetting - // reorder window. - reoWndPersist int8 - - // rtt is the RTT of the most recently delivered packet on the - // connection (either cumulatively acknowledged or selectively - // acknowledged) that was not marked invalid as a possible spurious - // retransmission. - rtt time.Duration - - // rttSeq is the SND.NXT when rtt is updated. - rttSeq seqnum.Value - - // xmitTime is the latest transmission timestamp of the most recent - // acknowledged segment. - xmitTime time.Time `state:".(unixTime)"` - // tlpRxtOut indicates whether there is an unacknowledged // TLP retransmission. tlpRxtOut bool @@ -108,8 +71,8 @@ type rackControl struct { // init initializes RACK specific fields. 
func (rc *rackControl) init(snd *sender, iss seqnum.Value) { - rc.fack = iss - rc.reoWndIncr = 1 + rc.FACK = iss + rc.ReoWndIncr = 1 rc.snd = snd } @@ -117,7 +80,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := time.Now().Sub(seg.xmitTime) - tsOffset := rc.snd.ep.tsOffset + tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: @@ -138,7 +101,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { } } - rc.rtt = rtt + rc.RTT = rtt // The sender can either track a simple global minimum of all RTT // measurements from the connection, or a windowed min-filtered value @@ -152,9 +115,9 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // ending sequence number of the packet which has been acknowledged // most recently. endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { - rc.xmitTime = seg.xmitTime - rc.endSequence = endSeq + if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + rc.XmitTime = seg.xmitTime + rc.EndSequence = endSeq } } @@ -171,18 +134,18 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // is identified. 
func (rc *rackControl) detectReorder(seg *segment) { endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.fack.LessThan(endSeq) { - rc.fack = endSeq + if rc.FACK.LessThan(endSeq) { + rc.FACK = endSeq return } - if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 { - rc.reorderSeen = true + if endSeq.LessThan(rc.FACK) && seg.xmitCount == 1 { + rc.Reord = true } } func (rc *rackControl) setDSACKSeen(dsackSeen bool) { - rc.dsackSeen = dsackSeen + rc.DSACKSeen = dsackSeen } // shouldSchedulePTO dictates whether we should schedule a PTO or not. @@ -191,7 +154,7 @@ func (s *sender) shouldSchedulePTO() bool { // Schedule PTO only if RACK loss detection is enabled. return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 && // The connection supports SACK. - s.ep.sackPermitted && + s.ep.SACKPermitted && // The connection is not in loss recovery. (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) && // The connection has no SACKed sequences in the SACK scoreboard. @@ -203,9 +166,9 @@ func (s *sender) shouldSchedulePTO() bool { func (s *sender) schedulePTO() { pto := time.Second s.rtt.Lock() - if s.rtt.srttInited && s.rtt.srtt > 0 { - pto = s.rtt.srtt * 2 - if s.outstanding == 1 { + if s.rtt.TCPRTTState.SRTTInited && s.rtt.TCPRTTState.SRTT > 0 { + pto = s.rtt.TCPRTTState.SRTT * 2 + if s.Outstanding == 1 { pto += wcDelayedACKTimeout } } @@ -230,10 +193,10 @@ func (s *sender) probeTimerExpired() tcpip.Error { } var dataSent bool - if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.outstanding < s.sndCwnd { - dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd)) + if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.Outstanding < s.SndCwnd { + dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd)) if dataSent { - s.outstanding += s.pCount(s.writeNext, s.maxPayloadSize) + s.Outstanding += s.pCount(s.writeNext, s.MaxPayloadSize) s.writeNext = s.writeNext.Next() 
} } @@ -255,10 +218,10 @@ func (s *sender) probeTimerExpired() tcpip.Error { } if highestSeqXmit != nil { - dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd)) + dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd)) if dataSent { s.rc.tlpRxtOut = true - s.rc.tlpHighRxt = s.sndNxt + s.rc.tlpHighRxt = s.SndNxt } } } @@ -274,7 +237,7 @@ func (s *sender) probeTimerExpired() tcpip.Error { // and updates TLP state accordingly. // See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3. func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { - if !(s.ep.sackPermitted && s.rc.tlpRxtOut) { + if !(s.ep.SACKPermitted && s.rc.tlpRxtOut) { return } @@ -317,13 +280,13 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { - dsackSeen := rc.dsackSeen + dsackSeen := rc.DSACKSeen snd := rc.snd // React to DSACK once per round trip. 
// If SND.UNA < RACK.rtt_seq: // RACK.dsack = false - if snd.sndUna.LessThan(rc.rttSeq) { + if snd.SndUna.LessThan(rc.RTTSeq) { dsackSeen = false } @@ -333,18 +296,18 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // RACK.rtt_seq = SND.NXT // RACK.reo_wnd_persist = 16 if dsackSeen { - rc.reoWndIncr++ + rc.ReoWndIncr++ dsackSeen = false - rc.rttSeq = snd.sndNxt - rc.reoWndPersist = tcpRACKRecoveryThreshold + rc.RTTSeq = snd.SndNxt + rc.ReoWndPersist = tcpRACKRecoveryThreshold } else if rc.exitedRecovery { // Else if exiting loss recovery: // RACK.reo_wnd_persist -= 1 // If RACK.reo_wnd_persist <= 0: // RACK.reo_wnd_incr = 1 - rc.reoWndPersist-- - if rc.reoWndPersist <= 0 { - rc.reoWndIncr = 1 + rc.ReoWndPersist-- + if rc.ReoWndPersist <= 0 { + rc.ReoWndIncr = 1 } rc.exitedRecovery = false } @@ -358,14 +321,14 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // Else if RACK.pkts_sacked >= RACK.dupthresh: // RACK.reo_wnd = 0 // return - if !rc.reorderSeen { + if !rc.Reord { if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery { - rc.reoWnd = 0 + rc.ReoWnd = 0 return } - if snd.sackedOut >= nDupAckThreshold { - rc.reoWnd = 0 + if snd.SackedOut >= nDupAckThreshold { + rc.ReoWnd = 0 return } } @@ -374,11 +337,11 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr // RACK.reo_wnd = min(RACK.reo_wnd, SRTT) snd.rtt.Lock() - srtt := snd.rtt.srtt + srtt := snd.rtt.TCPRTTState.SRTT snd.rtt.Unlock() - rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr)) - if srtt < rc.reoWnd { - rc.reoWnd = srtt + rc.ReoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.ReoWndIncr)) + if srtt < rc.ReoWnd { + rc.ReoWnd = srtt } } @@ -403,8 +366,8 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int { } endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if seg.xmitTime.Before(rc.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) 
&& rc.endSequence.LessThan(endSeq)) { - timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.rtt + rc.reoWnd + if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd if timeRemaining <= 0 { seg.lost = true numLost++ @@ -435,7 +398,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error { } fastRetransmit := false - if !rc.snd.fr.active { + if !rc.snd.FastRecovery.Active { rc.snd.cc.HandleLossDetected() rc.snd.enterRecovery() fastRetransmit = true @@ -471,15 +434,15 @@ func (rc *rackControl) DoRecovery(_ *segment, fastRetransmit bool) { } // Check the congestion window after entering recovery. - if snd.outstanding >= snd.sndCwnd { + if snd.Outstanding >= snd.SndCwnd { break } - if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.sndUna.Add(snd.sndWnd)); !sent { + if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.SndUna.Add(snd.SndWnd)); !sent { break } dataSent = true - snd.outstanding += snd.pCount(seg, snd.maxPayloadSize) + snd.Outstanding += snd.pCount(seg, snd.MaxPayloadSize) } snd.postXmit(dataSent, true /* shouldScheduleProbe */) diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index bc6793fc6..ee2c08cd6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // receiver holds the state necessary to receive TCP segments and turn them @@ -29,26 +30,15 @@ import ( // // +stateify savable type receiver struct { + stack.TCPReceiverState ep *endpoint - rcvNxt seqnum.Value - - // rcvAcc is one beyond the last acceptable sequence number. That is, - // the "largest" sequence value that the receiver has announced to the - // its peer that it's willing to accept. 
This may be different than - // rcvNxt + rcvWnd if the receive window is reduced; in that case we - // have to reduce the window as we receive more data instead of - // shrinking it. - rcvAcc seqnum.Value - // rcvWnd is the non-scaled receive window last advertised to the peer. rcvWnd seqnum.Size - // rcvWUP is the rcvNxt value at the last window update sent. + // rcvWUP is the RcvNxt value at the last window update sent. rcvWUP seqnum.Value - rcvWndScale uint8 - // prevBufused is the snapshot of endpoint rcvBufUsed taken when we // advertise a receive window. prevBufUsed int @@ -58,9 +48,6 @@ type receiver struct { // pendingRcvdSegments is bounded by the receive buffer size of the // endpoint. pendingRcvdSegments segmentHeap - // pendingBufUsed tracks the total number of bytes (including segment - // overhead) currently queued in pendingRcvdSegments. - pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` @@ -68,12 +55,14 @@ type receiver struct { func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), + ep: ep, + TCPReceiverState: stack.TCPReceiverState{ + RcvNxt: irs + 1, + RcvAcc: irs.Add(rcvWnd + 1), + RcvWndScale: rcvWndScale, + }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - rcvWndScale: rcvWndScale, lastRcvdAckTime: time.Now(), } } @@ -84,34 +73,34 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // r.rcvWnd could be much larger than the window size we advertised in our // outgoing packets, we should use what we have advertised for acceptability // test. - scaledWindowSize := r.rcvWnd >> r.rcvWndScale + scaledWindowSize := r.rcvWnd >> r.RcvWndScale if scaledWindowSize > math.MaxUint16 { // This is what we actually put in the Window field. 
scaledWindowSize = math.MaxUint16 } - advertisedWindowSize := scaledWindowSize << r.rcvWndScale - return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) + advertisedWindowSize := scaledWindowSize << r.RcvWndScale + return header.Acceptable(segSeq, segLen, r.RcvNxt, r.RcvNxt.Add(advertisedWindowSize)) } // currentWindow returns the available space in the window that was advertised // last to our peer. func (r *receiver) currentWindow() (curWnd seqnum.Size) { endOfWnd := r.rcvWUP.Add(r.rcvWnd) - if endOfWnd.LessThan(r.rcvNxt) { - // return 0 if r.rcvNxt is past the end of the previously advertised window. + if endOfWnd.LessThan(r.RcvNxt) { + // return 0 if r.RcvNxt is past the end of the previously advertised window. // This can happen because we accept a large segment completely even if // accepting it causes it to partially exceed the advertised window. return 0 } - return r.rcvNxt.Size(endOfWnd) + return r.RcvNxt.Size(endOfWnd) } // getSendParams returns the parameters needed by the sender when building // segments to send. -func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { +func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { newWnd := r.ep.selectWindow() curWnd := r.currentWindow() - unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt)) + unackLen := int(r.ep.snd.MaxSentAck.Size(r.RcvNxt)) bufUsed := r.ep.receiveBufferUsed() // Grow the right edge of the window only for payloads larger than the @@ -139,18 +128,18 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // edge, as we are still advertising a window that we think can be serviced. toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed - // Update rcvAcc only if new window is > previously advertised window. We + // Update RcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. 
If we shrink the acceptable sequence space then we // would end up dropping bytes that might already be in flight. // ==================================================== sequence space. // ^ ^ ^ ^ - // rcvWUP rcvNxt rcvAcc new rcvAcc + // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow { - // If the new window moves the right edge, then update rcvAcc. - r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) + if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + // If the new window moves the right edge, then update RcvAcc. + r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -162,9 +151,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. r.rcvWnd = newWnd - r.rcvWUP = r.rcvNxt + r.rcvWUP = r.RcvNxt r.prevBufUsed = bufUsed - scaledWnd := r.rcvWnd >> r.rcvWndScale + scaledWnd := r.rcvWnd >> r.RcvWndScale if scaledWnd == 0 { // Increment a metric if we are advertising an actual zero window. r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() @@ -177,9 +166,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Ensure that the stashed receive window always reflects what // is being advertised. - r.rcvWnd = scaledWnd << r.rcvWndScale + r.rcvWnd = scaledWnd << r.RcvWndScale } - return r.rcvNxt, scaledWnd + return r.RcvNxt, scaledWnd } // nonZeroWindow is called when the receive window grows from zero to nonzero; @@ -201,13 +190,13 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // If the segment doesn't include the seqnum we're expecting to // consume now, we're missing a segment. 
We cannot proceed until // we receive that segment though. - if !r.rcvNxt.InWindow(segSeq, segLen) { + if !r.RcvNxt.InWindow(segSeq, segLen) { return false } // Trim segment to eliminate already acknowledged data. - if segSeq.LessThan(r.rcvNxt) { - diff := segSeq.Size(r.rcvNxt) + if segSeq.LessThan(r.RcvNxt) { + diff := segSeq.Size(r.RcvNxt) segLen -= diff segSeq.UpdateForward(diff) s.sequenceNumber.UpdateForward(diff) @@ -217,35 +206,35 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Move segment to ready-to-deliver list. Wakeup any waiters. r.ep.readyToRead(s) - } else if segSeq != r.rcvNxt { + } else if segSeq != r.RcvNxt { return false } // Update the segment that we're expecting to consume. - r.rcvNxt = segSeq.Add(segLen) + r.RcvNxt = segSeq.Add(segLen) // In cases of a misbehaving sender which could send more than the // advertised window, we could end up in a situation where we get a // segment that exceeds the window advertised. Instead of partially // accepting the segment and discarding bytes beyond the advertised - // window, we accept the whole segment and make sure r.rcvAcc is moved - // forward to match r.rcvNxt to indicate that the window is now closed. + // window, we accept the whole segment and make sure r.RcvAcc is moved + // forward to match r.RcvNxt to indicate that the window is now closed. // // In absence of this check the r.acceptable() check fails and accepts // segments that should be dropped because rcvWnd is calculated as - // the size of the interval (rcvNxt, rcvAcc] which becomes extremely - // large if rcvAcc is ever less than rcvNxt. - if r.rcvAcc.LessThan(r.rcvNxt) { - r.rcvAcc = r.rcvNxt + // the size of the interval (RcvNxt, RcvAcc] which becomes extremely + // large if RcvAcc is ever less than RcvNxt. + if r.RcvAcc.LessThan(r.RcvNxt) { + r.RcvAcc = r.RcvNxt } // Trim SACK Blocks to remove any SACK information that covers // sequence numbers that have been consumed. 
- TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { - r.rcvNxt++ + r.RcvNxt++ // Send ACK immediately. r.ep.snd.sendAck() @@ -260,7 +249,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { @@ -280,7 +269,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { - r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() + r.PendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() // Note that slice truncation does not allow garbage collection of @@ -295,7 +284,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -323,40 +312,40 @@ func (r *receiver) updateRTT() { // estimate the round-trip time by observing the time between when a byte // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. - r.ep.rcvListMu.Lock() - if r.ep.rcvAutoParams.rttMeasureTime.IsZero() { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { // New measurement. 
- r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) { - r.ep.rcvListMu.Unlock() + if r.RcvNxt.LessThan(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber) { + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime) + rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. - if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt { - r.ep.rcvAutoParams.rtt = rtt + if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { + r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { - r.ep.rcvListMu.Lock() - rcvClosed := r.ep.rcvClosed || r.closed - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := r.ep.rcvQueueInfo.RcvClosed || r.closed + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. 
switch state { case StateCloseWait, StateClosing, StateLastAck: - if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + if !s.sequenceNumber.LessThanEq(r.RcvNxt) { // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -384,17 +373,17 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // The ESTABLISHED state processing is here where if the ACK check // fails, we ignore the packet: // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + if r.ep.snd.SndNxt.LessThan(s.ackNumber) { r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., - // SHUT_RD) then any data past the rcvNxt should + // SHUT_RD) then any data past the RcvNxt should // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + if state != StateCloseWait && rcvClosed && r.RcvNxt.LessThan(endDataSeq) { return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { @@ -403,7 +392,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If it's a retransmission of an old data segment // or a pure ACK then allow it. - if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.RcvNxt) || s.logicalLen() == 0 { break } @@ -413,7 +402,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // then the only acceptable segment is a // FIN. Since FIN can technically also carry // data we verify that the segment carrying a - // FIN ends at exactly e.rcvNxt+1. + // FIN ends at exactly e.RcvNxt+1. // // From RFC793 page 25. 
// @@ -423,7 +412,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -435,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // end has closed and the peer is yet to send a FIN. Hence we // compare only the payload. segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + if rcvClosed && !segEnd.LessThanEq(r.RcvNxt) { return true, nil } return false, nil @@ -477,13 +466,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. - if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { - r.ep.rcvListMu.Lock() - r.pendingBufUsed += s.segMemSize() - r.ep.rcvListMu.Unlock() + if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed += s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) - UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) + UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.RcvNxt) } // Immediately send an ack so that the peer knows it may @@ -508,15 +497,15 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { segSeq := s.sequenceNumber // Skip segment altogether if it has already been acknowledged. 
- if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) && + if !segSeq.Add(segLen-1).LessThan(r.RcvNxt) && !r.consumeSegment(s, segSeq, segLen) { break } heap.Pop(&r.pendingRcvdSegments) - r.ep.rcvListMu.Lock() - r.pendingBufUsed -= s.segMemSize() - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed -= s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.decRef() } return false, nil @@ -558,7 +547,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } @@ -569,11 +558,11 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn } // Update Timestamp if required. See RFC7323, section-4.3. - if r.ep.sendTSOk && s.parsedOptions.TS { - r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + if r.ep.SendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() @@ -584,8 +573,8 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // carries data then just send an ACK. This is according to RFC 793, // page 37. // - // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. - if segSeq != r.rcvNxt || segLen != 0 { + // NOTE: In TIME_WAIT the only acceptable sequence number is RcvNxt. 
+ if segSeq != r.RcvNxt || segLen != 0 { r.ep.snd.sendAck() } return false, false diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go index ff39780a5..063552c7f 100644 --- a/pkg/tcpip/transport/tcp/reno.go +++ b/pkg/tcpip/transport/tcp/reno.go @@ -34,14 +34,14 @@ func newRenoCC(s *sender) *renoState { func (r *renoState) updateSlowStart(packetsAcked int) int { // Don't let the congestion window cross into the congestion // avoidance range. - newcwnd := r.s.sndCwnd + packetsAcked - if newcwnd >= r.s.sndSsthresh { - newcwnd = r.s.sndSsthresh - r.s.sndCAAckCount = 0 + newcwnd := r.s.SndCwnd + packetsAcked + if newcwnd >= r.s.Ssthresh { + newcwnd = r.s.Ssthresh + r.s.SndCAAckCount = 0 } - packetsAcked -= newcwnd - r.s.sndCwnd - r.s.sndCwnd = newcwnd + packetsAcked -= newcwnd - r.s.SndCwnd + r.s.SndCwnd = newcwnd return packetsAcked } @@ -49,19 +49,19 @@ func (r *renoState) updateSlowStart(packetsAcked int) int { // avoidance mode as described in RFC5681 section 3.1 func (r *renoState) updateCongestionAvoidance(packetsAcked int) { // Consume the packets in congestion avoidance mode. - r.s.sndCAAckCount += packetsAcked - if r.s.sndCAAckCount >= r.s.sndCwnd { - r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd - r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd + r.s.SndCAAckCount += packetsAcked + if r.s.SndCAAckCount >= r.s.SndCwnd { + r.s.SndCwnd += r.s.SndCAAckCount / r.s.SndCwnd + r.s.SndCAAckCount = r.s.SndCAAckCount % r.s.SndCwnd } } // reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, // page 6, eq. 4. It is called when we detect congestion in the network. func (r *renoState) reduceSlowStartThreshold() { - r.s.sndSsthresh = r.s.outstanding / 2 - if r.s.sndSsthresh < 2 { - r.s.sndSsthresh = 2 + r.s.Ssthresh = r.s.Outstanding / 2 + if r.s.Ssthresh < 2 { + r.s.Ssthresh = 2 } } @@ -70,7 +70,7 @@ func (r *renoState) reduceSlowStartThreshold() { // were acknowledged. // Update implements congestionControl.Update. 
func (r *renoState) Update(packetsAcked int) { - if r.s.sndCwnd < r.s.sndSsthresh { + if r.s.SndCwnd < r.s.Ssthresh { packetsAcked = r.updateSlowStart(packetsAcked) if packetsAcked == 0 { return @@ -94,7 +94,7 @@ func (r *renoState) HandleRTOExpired() { // Reduce the congestion window to 1, i.e., enter slow-start. Per // RFC 5681, page 7, we must use 1 regardless of the value of the // initial congestion window. - r.s.sndCwnd = 1 + r.s.SndCwnd = 1 } // PostRecovery implements congestionControl.PostRecovery. diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go index 2aa708e97..d368a29fc 100644 --- a/pkg/tcpip/transport/tcp/reno_recovery.go +++ b/pkg/tcpip/transport/tcp/reno_recovery.go @@ -31,25 +31,25 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { snd := rr.s // We are in fast recovery mode. Ignore the ack if it's out of range. - if !ack.InRange(snd.sndUna, snd.sndNxt+1) { + if !ack.InRange(snd.SndUna, snd.SndNxt+1) { return } // Don't count this as a duplicate if it is carrying data or // updating the window. - if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window { + if rcvdSeg.logicalLen() != 0 || snd.SndWnd != rcvdSeg.window { return } // Inflate the congestion window if we're getting duplicate acks // for the packet we retransmitted. - if !fastRetransmit && ack == snd.fr.first { + if !fastRetransmit && ack == snd.FastRecovery.First { // We received a dup, inflate the congestion window by 1 packet // if we're not at the max yet. Only inflate the window if // regular FastRecovery is in use, RFC6675 does not require // inflating cwnd on duplicate ACKs. - if snd.sndCwnd < snd.fr.maxCwnd { - snd.sndCwnd++ + if snd.SndCwnd < snd.FastRecovery.MaxCwnd { + snd.SndCwnd++ } return } @@ -61,7 +61,7 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { // back onto the wire. // // N.B. The retransmit timer will be reset by the caller. 
- snd.fr.first = ack - snd.dupAckCount = 0 + snd.FastRecovery.First = ack + snd.DupAckCount = 0 snd.resendSegment() } diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go index 9d406b0bc..cd860b5e8 100644 --- a/pkg/tcpip/transport/tcp/sack_recovery.go +++ b/pkg/tcpip/transport/tcp/sack_recovery.go @@ -42,14 +42,14 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen } nextSegHint := snd.writeList.Front() - for snd.outstanding < snd.sndCwnd { + for snd.Outstanding < snd.SndCwnd { var nextSeg *segment var rescueRtx bool nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint) if nextSeg == nil { return dataSent } - if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + if !snd.isAssignedSequenceNumber(nextSeg) || snd.SndNxt.LessThanEq(nextSeg.sequenceNumber) { // New data being sent. // Step C.3 described below is handled by @@ -67,7 +67,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen return dataSent } dataSent = true - snd.outstanding++ + snd.Outstanding++ snd.writeNext = nextSeg.Next() continue } @@ -79,7 +79,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // "The estimate of the amount of data outstanding in the network // must be updated by incrementing pipe by the number of octets // transmitted in (C.1)." - snd.outstanding++ + snd.Outstanding++ dataSent = true snd.sendSegment(nextSeg) @@ -88,7 +88,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // We do the last part of rule (4) of NextSeg here to update // RescueRxt as until this point we don't know if we are going // to use the rescue transmission. 
- snd.fr.rescueRxt = snd.fr.last + snd.FastRecovery.RescueRxt = snd.FastRecovery.Last } else { // RFC 6675, Step C.2 // @@ -96,7 +96,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // HighData, HighRxt MUST be set to the highest sequence // number of the retransmitted segment unless NextSeg () // rule (4) was invoked for this retransmission." - snd.fr.highRxt = segEnd - 1 + snd.FastRecovery.HighRxt = segEnd - 1 } } return dataSent @@ -109,12 +109,12 @@ func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { } // We are in fast recovery mode. Ignore the ack if it's out of range. - if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) { + if ack := rcvdSeg.ackNumber; !ack.InRange(snd.SndUna, snd.SndNxt+1) { return } // RFC 6675 recovery algorithm step C 1-5. - end := snd.sndUna.Add(snd.sndWnd) - dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end) + end := snd.SndUna.Add(snd.SndWnd) + dataSent := sr.handleSACKRecovery(snd.MaxPayloadSize, end) snd.postXmit(dataSent, true /* shouldScheduleProbe */) } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 8edd6775b..c28641be3 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -236,20 +236,14 @@ func (s *segment) parse(skipChecksumValidation bool) bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - - verifyChecksum := true if skipChecksumValidation { s.csumValid = true - verifyChecksum = false - } - if verifyChecksum { + } else { s.csum = s.hdr.Checksum() - xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) - xsum = s.hdr.CalculateChecksum(xsum) - xsum = header.ChecksumVV(s.data, xsum) - s.csumValid = xsum == 0xffff + payloadChecksum := header.ChecksumVV(s.data, 0) + payloadLength := uint16(s.data.Size()) + s.csumValid = s.hdr.IsChecksumValid(s.srcAddr, 
s.dstAddr, payloadChecksum, payloadLength) } - s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) s.ackNumber = seqnum.Value(s.hdr.AckNumber()) s.flags = s.hdr.Flags() diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 54545a1b1..d0d1b0b8a 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -52,12 +52,12 @@ func (q *segmentQueue) empty() bool { func (q *segmentQueue) enqueue(s *segment) bool { // q.ep.receiveBufferParams() must be called without holding q.mu to // avoid lock order inversion. - bufSz := q.ep.receiveBufferSize() + bufSz := q.ep.ops.GetReceiveBufferSize() used := q.ep.receiveMemUsed() q.mu.Lock() // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue // is currently full). - allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + allow := (used <= int(bufSz) || s.payloadSize() == 0) && !q.frozen if allow { q.list.PushBack(s) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index faca35892..2b32cb7b2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -85,56 +86,12 @@ type lossRecovery interface { // // +stateify savable type sender struct { + stack.TCPSenderState ep *endpoint - // lastSendTime is the timestamp when the last packet was sent. - lastSendTime time.Time `state:".(unixTime)"` - - // dupAckCount is the number of duplicated acks received. It is used for - // fast retransmit. - dupAckCount int - - // fr holds state related to fast recovery. - fr fastRecovery - // lr is the loss recovery algorithm used by the sender. lr lossRecovery - // sndCwnd is the congestion window, in packets. 
- sndCwnd int - - // sndSsthresh is the threshold between slow start and congestion - // avoidance. - sndSsthresh int - - // sndCAAckCount is the number of packets acknowledged during congestion - // avoidance. When enough packets have been ack'd (typically cwnd - // packets), the congestion window is incremented by one. - sndCAAckCount int - - // outstanding is the number of outstanding packets, that is, packets - // that have been sent but not yet acknowledged. - outstanding int - - // sackedOut is the number of packets which are selectively acked. - sackedOut int - - // sndWnd is the send window size. - sndWnd seqnum.Size - - // sndUna is the next unacknowledged sequence number. - sndUna seqnum.Value - - // sndNxt is the sequence number of the next segment to be sent. - sndNxt seqnum.Value - - // rttMeasureSeqNum is the sequence number being used for the latest RTT - // measurement. - rttMeasureSeqNum seqnum.Value - - // rttMeasureTime is the time when the rttMeasureSeqNum was sent. - rttMeasureTime time.Time `state:".(unixTime)"` - // firstRetransmittedSegXmitTime is the original transmit time of // the first segment that was retransmitted due to RTO expiration. firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` @@ -147,17 +104,15 @@ type sender struct { // window probes. unackZeroWindowProbes uint32 `state:"nosave"` - closed bool writeNext *segment writeList segmentList resendTimer timer `state:"nosave"` resendWaker sleep.Waker `state:"nosave"` - // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time", - // "round-trip time variation" and "retransmit timeout", as defined in + // rtt.TCPRTTState.SRTT and rtt.TCPRTTState.RTTVar are the "smoothed + // round-trip time", and "round-trip time variation", as defined in // section 2 of RFC 6298. rtt rtt - rto time.Duration // minRTO is the minimum permitted value for sender.rto. minRTO time.Duration @@ -168,20 +123,9 @@ type sender struct { // maxRetries is the maximum permitted retransmissions. 
maxRetries uint32 - // maxPayloadSize is the maximum size of the payload of a given segment. - // It is initialized on demand. - maxPayloadSize int - // gso is set if generic segmentation offload is enabled. gso bool - // sndWndScale is the number of bits to shift left when reading the send - // window size from a segment. - sndWndScale uint8 - - // maxSentAck is the maxium acknowledgement actually sent. - maxSentAck seqnum.Value - // state is the current state of congestion control for this endpoint. state tcpip.CongestionControlState @@ -209,41 +153,7 @@ type sender struct { type rtt struct { sync.Mutex `state:"nosave"` - srtt time.Duration - rttvar time.Duration - srttInited bool -} - -// fastRecovery holds information related to fast recovery from a packet loss. -// -// +stateify savable -type fastRecovery struct { - // active whether the endpoint is in fast recovery. The following fields - // are only meaningful when active is true. - active bool - - // first and last represent the inclusive sequence number range being - // recovered. - first seqnum.Value - last seqnum.Value - - // maxCwnd is the maximum value the congestion window may be inflated to - // due to duplicate acks. This exists to avoid attacks where the - // receiver intentionally sends duplicate acks to artificially inflate - // the sender's cwnd. - maxCwnd int - - // highRxt is the highest sequence number which has been retransmitted - // during the current loss recovery phase. - // See: RFC 6675 Section 2 for details. - highRxt seqnum.Value - - // rescueRxt is the highest sequence number which has been - // optimistically retransmitted to prevent stalling of the ACK clock - // when there is loss at the end of the window and no new data is - // available for transmission. - // See: RFC 6675 Section 2 for details. 
- rescueRxt seqnum.Value + stack.TCPRTTState } func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { @@ -253,22 +163,24 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint maxPayloadSize := int(mss) - ep.maxOptionSize() s := &sender{ - ep: ep, - sndWnd: sndWnd, - sndUna: iss + 1, - sndNxt: iss + 1, - rto: 1 * time.Second, - rttMeasureSeqNum: iss + 1, - lastSendTime: time.Now(), - maxPayloadSize: maxPayloadSize, - maxSentAck: irs + 1, - fr: fastRecovery{ - // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. - last: iss, - highRxt: iss, - rescueRxt: iss, + ep: ep, + TCPSenderState: stack.TCPSenderState{ + SndWnd: sndWnd, + SndUna: iss + 1, + SndNxt: iss + 1, + RTTMeasureSeqNum: iss + 1, + LastSendTime: time.Now(), + MaxPayloadSize: maxPayloadSize, + MaxSentAck: irs + 1, + FastRecovery: stack.TCPFastRecoveryState{ + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. + Last: iss, + HighRxt: iss, + RescueRxt: iss, + }, + RTO: 1 * time.Second, }, - gso: ep.gso != nil, + gso: ep.gso.Type != stack.GSONone, } if s.gso { @@ -282,7 +194,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. if sndWndScale > 0 { - s.sndWndScale = uint8(sndWndScale) + s.SndWndScale = uint8(sndWndScale) } s.resendTimer.init(&s.resendWaker) @@ -294,7 +206,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // Initialize SACK Scoreboard after updating max payload size as we use // the maxPayloadSize as the smss when determining if a segment is lost // etc. - s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + s.ep.scoreboard = NewSACKScoreboard(uint16(s.MaxPayloadSize), iss) // Get Stack wide config. 
var minRTO tcpip.TCPMinRTOOption @@ -322,10 +234,10 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // returns a handle to it. It also initializes the sndCwnd and sndSsThresh to // their initial values. func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl { - s.sndCwnd = InitialCwnd + s.SndCwnd = InitialCwnd // Set sndSsthresh to the maximum int value, which depends on the // platform. - s.sndSsthresh = int(^uint(0) >> 1) + s.Ssthresh = int(^uint(0) >> 1) switch congestionControlName { case ccCubic: @@ -339,7 +251,7 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon // initLossRecovery initiates the loss recovery algorithm for the sender. func (s *sender) initLossRecovery() lossRecovery { - if s.ep.sackPermitted { + if s.ep.SACKPermitted { return newSACKRecovery(s) } return newRenoRecovery(s) @@ -355,7 +267,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m -= s.ep.maxOptionSize() // We don't adjust up for now. - if m >= s.maxPayloadSize { + if m >= s.MaxPayloadSize { return } @@ -364,8 +276,8 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } - oldMSS := s.maxPayloadSize - s.maxPayloadSize = m + oldMSS := s.MaxPayloadSize + s.MaxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) } @@ -380,9 +292,9 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // maxPayloadSize. s.ep.scoreboard.smss = uint16(m) - s.outstanding -= count - if s.outstanding < 0 { - s.outstanding = 0 + s.Outstanding -= count + if s.Outstanding < 0 { + s.Outstanding = 0 } // Rewind writeNext to the first segment exceeding the MTU. Do nothing @@ -401,10 +313,10 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { nextSeg = seg } - if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + if s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Update sackedOut for new maximum payload size. 
- s.sackedOut -= s.pCount(seg, oldMSS) - s.sackedOut += s.pCount(seg, s.maxPayloadSize) + s.SackedOut -= s.pCount(seg, oldMSS) + s.SackedOut += s.pCount(seg, s.MaxPayloadSize) } } @@ -416,32 +328,32 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // sendAck sends an ACK segment. func (s *sender) sendAck() { - s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt) + s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.SndNxt) } // updateRTO updates the retransmit timeout when a new roud-trip time is // available. This is done in accordance with section 2 of RFC 6298. func (s *sender) updateRTO(rtt time.Duration) { s.rtt.Lock() - if !s.rtt.srttInited { - s.rtt.rttvar = rtt / 2 - s.rtt.srtt = rtt - s.rtt.srttInited = true + if !s.rtt.TCPRTTState.SRTTInited { + s.rtt.TCPRTTState.RTTVar = rtt / 2 + s.rtt.TCPRTTState.SRTT = rtt + s.rtt.TCPRTTState.SRTTInited = true } else { - diff := s.rtt.srtt - rtt + diff := s.rtt.TCPRTTState.SRTT - rtt if diff < 0 { diff = -diff } - // Use RFC6298 standard algorithm to update rttvar and srtt when + // Use RFC6298 standard algorithm to update TCPRTTState.RTTVar and TCPRTTState.SRTT when // no timestamps are available. - if !s.ep.sendTSOk { - s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4 - s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8 + if !s.ep.SendTSOk { + s.rtt.TCPRTTState.RTTVar = (3*s.rtt.TCPRTTState.RTTVar + diff) / 4 + s.rtt.TCPRTTState.SRTT = (7*s.rtt.TCPRTTState.SRTT + rtt) / 8 } else { // When we are taking RTT measurements of every ACK then // we need to use a modified method as specified in // https://tools.ietf.org/html/rfc7323#appendix-G - if s.outstanding == 0 { + if s.Outstanding == 0 { s.rtt.Unlock() return } @@ -449,7 +361,7 @@ func (s *sender) updateRTO(rtt time.Duration) { // terms of packets and not bytes. This is similar to // how linux also does cwnd and inflight. In practice // this approximation works as expected. 
- expectedSamples := math.Ceil(float64(s.outstanding) / 2) + expectedSamples := math.Ceil(float64(s.Outstanding) / 2) // alpha & beta values are the original values as recommended in // https://tools.ietf.org/html/rfc6298#section-2.3. @@ -458,17 +370,17 @@ func (s *sender) updateRTO(rtt time.Duration) { alphaPrime := alpha / expectedSamples betaPrime := beta / expectedSamples - rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds() - srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds() - s.rtt.rttvar = time.Duration(rttVar * float64(time.Second)) - s.rtt.srtt = time.Duration(srtt * float64(time.Second)) + rttVar := (1-betaPrime)*s.rtt.TCPRTTState.RTTVar.Seconds() + betaPrime*diff.Seconds() + srtt := (1-alphaPrime)*s.rtt.TCPRTTState.SRTT.Seconds() + alphaPrime*rtt.Seconds() + s.rtt.TCPRTTState.RTTVar = time.Duration(rttVar * float64(time.Second)) + s.rtt.TCPRTTState.SRTT = time.Duration(srtt * float64(time.Second)) } } - s.rto = s.rtt.srtt + 4*s.rtt.rttvar + s.RTO = s.rtt.TCPRTTState.SRTT + 4*s.rtt.TCPRTTState.RTTVar s.rtt.Unlock() - if s.rto < s.minRTO { - s.rto = s.minRTO + if s.RTO < s.minRTO { + s.RTO = s.minRTO } } @@ -476,20 +388,20 @@ func (s *sender) updateRTO(rtt time.Duration) { func (s *sender) resendSegment() { // Don't use any segments we already sent to measure RTT as they may // have been affected by packets being lost. - s.rttMeasureSeqNum = s.sndNxt + s.RTTMeasureSeqNum = s.SndNxt // Resend the segment. if seg := s.writeList.Front(); seg != nil { - if seg.data.Size() > s.maxPayloadSize { - s.splitSeg(seg, s.maxPayloadSize) + if seg.data.Size() > s.MaxPayloadSize { + s.splitSeg(seg, s.MaxPayloadSize) } // See: RFC 6675 section 5 Step 4.3 // // To prevent retransmission, set both the HighRXT and RescueRXT // to the highest sequence number in the retransmitted segment. 
- s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 - s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.FastRecovery.HighRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.FastRecovery.RescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() s.ep.stats.SendErrors.FastRetransmit.Increment() @@ -554,15 +466,15 @@ func (s *sender) retransmitTimerExpired() bool { // Set new timeout. The timer will be restarted by the call to sendData // below. - s.rto *= 2 + s.RTO *= 2 // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5 - if s.rto > s.maxRTO { - s.rto = s.maxRTO + if s.RTO > s.maxRTO { + s.RTO = s.maxRTO } // Cap RTO to remaining time. - if s.rto > remaining { - s.rto = remaining + if s.RTO > remaining { + s.RTO = remaining } // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. @@ -571,9 +483,9 @@ func (s *sender) retransmitTimerExpired() bool { // After a retransmit timeout, record the highest sequence number // transmitted in the variable recover, and exit the fast recovery // procedure if applicable. - s.fr.last = s.sndNxt - 1 + s.FastRecovery.Last = s.SndNxt - 1 - if s.fr.active { + if s.FastRecovery.Active { // We were attempting fast recovery but were not successful. // Leave the state. We don't need to update ssthresh because it // has already been updated when entered fast-recovery. @@ -589,7 +501,7 @@ func (s *sender) retransmitTimerExpired() bool { // // We'll keep on transmitting (or retransmitting) as we get acks for // the data we transmit. - s.outstanding = 0 + s.Outstanding = 0 // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1 // @@ -663,7 +575,7 @@ func (s *sender) splitSeg(seg *segment, size int) { // window space. 
// ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() - if seg.data.Size() > s.maxPayloadSize { + if seg.data.Size() > s.MaxPayloadSize { seg.flags ^= header.TCPFlagPsh } @@ -689,7 +601,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // transmitted (i.e. either it has no assigned sequence number // or if it does have one, it's >= the next sequence number // to be sent [i.e. >= s.sndNxt]). - if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) { + if !s.isAssignedSequenceNumber(seg) || s.SndNxt.LessThanEq(seg.sequenceNumber) { hint = nil break } @@ -710,7 +622,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // (1.a) S2 is greater than HighRxt // (1.b) S2 is less than highest octect covered by // any received SACK. - if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { + if s.FastRecovery.HighRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { // NextSeg(): // (1.c) IsLost(S2) returns true. if s.ep.scoreboard.IsLost(segSeq) { @@ -743,7 +655,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // unSACKed sequence number SHOULD be returned, and // RescueRxt set to RecoveryPoint. HighRxt MUST NOT // be updated. - if s.fr.rescueRxt.LessThan(s.sndUna - 1) { + if s.FastRecovery.RescueRxt.LessThan(s.SndUna - 1) { if s4 != nil { if s4.sequenceNumber.LessThan(segSeq) { s4 = seg @@ -763,7 +675,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // previously unsent data starting with sequence number // HighData+1 MUST be returned." 
for seg := s.writeNext; seg != nil; seg = seg.Next() { - if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { + if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.SndNxt) { continue } // We do not split the segment here to <= smss as it has @@ -788,7 +700,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(s.sndNxt.Size(end)) + available := int(s.SndNxt.Size(end)) if available > limit { available = limit } @@ -816,7 +728,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } if !nextTooBig && seg.data.Size() < available { // Segment is not full. - if s.outstanding > 0 && s.ep.ops.GetDelayOption() { + if s.Outstanding > 0 && s.ep.ops.GetDelayOption() { // Nagle's algorithm. From Wikipedia: // Nagle's algorithm works by // combining a number of small @@ -835,7 +747,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // send space and MSS. // TODO(gvisor.dev/issue/2833): Drain the held segments after a // timeout. - if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() { + if seg.data.Size() < s.MaxPayloadSize && s.ep.ops.GetCorkOption() { return false } } @@ -843,7 +755,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // Assign flags. We don't do it above so that we can merge // additional data if Nagle holds the segment. - seg.sequenceNumber = s.sndNxt + seg.sequenceNumber = s.SndNxt seg.flags = header.TCPFlagAck | header.TCPFlagPsh } @@ -893,12 +805,12 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // the segment right here if there are no pending segments. If // there are pending segments, segment transmits are deferred to // the retransmit timer handler. 
- if s.sndUna != s.sndNxt { + if s.SndUna != s.SndNxt { switch { case available >= seg.data.Size(): // OK to send, the whole segments fits in the // receiver's advertised window. - case available >= s.maxPayloadSize: + case available >= s.MaxPayloadSize: // OK to send, at least 1 MSS sized segment fits // in the receiver's advertised window. default: @@ -918,8 +830,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // If GSO is not in use then cap available to // maxPayloadSize. When GSO is in use the gVisor GSO logic or // the host GSO logic will cap the segment to the correct size. - if s.ep.gso == nil && available > s.maxPayloadSize { - available = s.maxPayloadSize + if s.ep.gso.Type == stack.GSONone && available > s.MaxPayloadSize { + available = s.MaxPayloadSize } if seg.data.Size() > available { @@ -933,8 +845,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // Update sndNxt if we actually sent new data (as opposed to // retransmitting some previously sent data). - if s.sndNxt.LessThan(segEnd) { - s.sndNxt = segEnd + if s.SndNxt.LessThan(segEnd) { + s.SndNxt = segEnd } return true @@ -945,9 +857,9 @@ func (s *sender) sendZeroWindowProbe() { s.unackZeroWindowProbes++ // Send a zero window probe with sequence number pointing to // the last acknowledged byte. - s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win) + s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.SndUna-1, ack, win) // Rearm the timer to continue probing. 
- s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } func (s *sender) enableZeroWindowProbing() { @@ -958,7 +870,7 @@ func (s *sender) enableZeroWindowProbing() { if s.firstRetransmittedSegXmitTime.IsZero() { s.firstRetransmittedSegXmitTime = time.Now() } - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } func (s *sender) disableZeroWindowProbing() { @@ -978,12 +890,12 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { // If the sender has advertized zero receive window and we have // data to be sent out, start zero window probing to query the // the remote for it's receive window size. - if s.writeNext != nil && s.sndWnd == 0 { + if s.writeNext != nil && s.SndWnd == 0 { s.enableZeroWindowProbing() } // If we have no more pending data, start the keepalive timer. - if s.sndUna == s.sndNxt { + if s.SndUna == s.SndNxt { s.ep.resetKeepaliveTimer(false) } else { // Enable timers if we have pending data. @@ -992,10 +904,10 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { s.schedulePTO() } else if !s.resendTimer.enabled() { s.probeTimer.disable() - if s.outstanding > 0 { + if s.Outstanding > 0 { // Enable the resend timer if it's not enabled yet and there is // outstanding data. - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } } } @@ -1004,29 +916,29 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. func (s *sender) sendData() { - limit := s.maxPayloadSize + limit := s.MaxPayloadSize if s.gso { limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) } - end := s.sndUna.Add(s.sndWnd) + end := s.SndUna.Add(s.SndWnd) // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." 
- if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { - if s.sndCwnd > InitialCwnd { - s.sndCwnd = InitialCwnd + if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO { + if s.SndCwnd > InitialCwnd { + s.SndCwnd = InitialCwnd } } var dataSent bool - for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { - cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + for seg := s.writeNext; seg != nil && s.Outstanding < s.SndCwnd; seg = seg.Next() { + cwndLimit := (s.SndCwnd - s.Outstanding) * s.MaxPayloadSize if cwndLimit < limit { limit = cwndLimit } - if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + if s.isAssignedSequenceNumber(seg) && s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Move writeNext along so that we don't try and scan data that // has already been SACKED. s.writeNext = seg.Next() @@ -1036,7 +948,7 @@ func (s *sender) sendData() { break } dataSent = true - s.outstanding += s.pCount(seg, s.maxPayloadSize) + s.Outstanding += s.pCount(seg, s.MaxPayloadSize) s.writeNext = seg.Next() } @@ -1044,21 +956,21 @@ func (s *sender) sendData() { } func (s *sender) enterRecovery() { - s.fr.active = true + s.FastRecovery.Active = true // Save state to reflect we're now in fast recovery. // // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. 
- s.sndCwnd = s.sndSsthresh + 3 - s.sackedOut = 0 - s.dupAckCount = 0 - s.fr.first = s.sndUna - s.fr.last = s.sndNxt - 1 - s.fr.maxCwnd = s.sndCwnd + s.outstanding - s.fr.highRxt = s.sndUna - s.fr.rescueRxt = s.sndUna - if s.ep.sackPermitted { + s.SndCwnd = s.Ssthresh + 3 + s.SackedOut = 0 + s.DupAckCount = 0 + s.FastRecovery.First = s.SndUna + s.FastRecovery.Last = s.SndNxt - 1 + s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding + s.FastRecovery.HighRxt = s.SndUna + s.FastRecovery.RescueRxt = s.SndUna + if s.ep.SACKPermitted { s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() // Set TLPRxtOut to false according to @@ -1075,12 +987,12 @@ func (s *sender) enterRecovery() { } func (s *sender) leaveRecovery() { - s.fr.active = false - s.fr.maxCwnd = 0 - s.dupAckCount = 0 + s.FastRecovery.Active = false + s.FastRecovery.MaxCwnd = 0 + s.DupAckCount = 0 // Deflate cwnd. It had been artificially inflated when new dups arrived. - s.sndCwnd = s.sndSsthresh + s.SndCwnd = s.Ssthresh s.cc.PostRecovery() } @@ -1099,7 +1011,7 @@ func (s *sender) isAssignedSequenceNumber(seg *segment) bool { func (s *sender) SetPipe() { // If SACK isn't permitted or it is permitted but recovery is not active // then ignore pipe calculations. - if !s.ep.sackPermitted || !s.fr.active { + if !s.ep.SACKPermitted || !s.FastRecovery.Active { return } pipe := 0 @@ -1119,7 +1031,7 @@ func (s *sender) SetPipe() { // After initializing pipe to zero, the following steps are // taken for each octet 'S1' in the sequence space between // HighACK and HighData that has not been SACKed: - if !s1.sequenceNumber.LessThan(s.sndNxt) { + if !s1.sequenceNumber.LessThan(s.SndNxt) { break } if s.ep.scoreboard.IsSACKED(sb) { @@ -1138,20 +1050,20 @@ func (s *sender) SetPipe() { } // SetPipe(): // (b) If S1 <= HighRxt, Pipe is incremented by 1. 
- if s1.sequenceNumber.LessThanEq(s.fr.highRxt) { + if s1.sequenceNumber.LessThanEq(s.FastRecovery.HighRxt) { pipe++ } } } - s.outstanding = pipe + s.Outstanding = pipe } // shouldEnterRecovery returns true if the sender should enter fast recovery // based on dupAck count and sack scoreboard. // See RFC 6675 section 5. func (s *sender) shouldEnterRecovery() bool { - return s.dupAckCount >= nDupAckThreshold || - (s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.sndUna)) + return s.DupAckCount >= nDupAckThreshold || + (s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.SndUna)) } // detectLoss is called when an ack is received and returns whether a loss is @@ -1163,24 +1075,24 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // If RACK is enabled and there is no reordering we should honor the // three duplicate ACK rule to enter recovery. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-4 - if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { - if s.rc.reorderSeen { + if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.rc.Reord { return false } } if !s.isDupAck(seg) { - s.dupAckCount = 0 + s.DupAckCount = 0 return false } - s.dupAckCount++ + s.DupAckCount++ // Do not enter fast recovery until we reach nDupAckThreshold or the // first unacknowledged byte is considered lost as per SACK scoreboard. if !s.shouldEnterRecovery() { // RFC 6675 Step 3. - s.fr.highRxt = s.sndUna - 1 + s.FastRecovery.HighRxt = s.SndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() s.state = tcpip.Disorder @@ -1196,8 +1108,8 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // Note that we only enter recovery when at least one more byte of data // beyond s.fr.last (the highest byte that was outstanding when fast // retransmit was last entered) is acked. 
- if !s.fr.last.LessThan(seg.ackNumber - 1) { - s.dupAckCount = 0 + if !s.FastRecovery.Last.LessThan(seg.ackNumber - 1) { + s.DupAckCount = 0 return false } s.cc.HandleLossDetected() @@ -1212,22 +1124,22 @@ func (s *sender) isDupAck(seg *segment) bool { // can leverage the SACK information to determine when an incoming ACK is a // "duplicate" (e.g., if the ACK contains previously unknown SACK // information). - if s.ep.sackPermitted && !seg.hasNewSACKInfo { + if s.ep.SACKPermitted && !seg.hasNewSACKInfo { return false } // (a) The receiver of the ACK has outstanding data. - return s.sndUna != s.sndNxt && + return s.SndUna != s.SndNxt && // (b) The incoming acknowledgment carries no data. seg.logicalLen() == 0 && // (c) The SYN and FIN bits are both off. !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && // (d) the ACK number is equal to the greatest acknowledgment received on // the given connection (TCP.UNA from RFC793). - seg.ackNumber == s.sndUna && + seg.ackNumber == s.SndUna && // (e) the advertised window in the incoming acknowledgment equals the // advertised window in the last incoming acknowledgment. - s.sndWnd == seg.window + s.SndWnd == seg.window } // Iterate the writeList and update RACK for each segment which is newly acked @@ -1267,7 +1179,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) seg.acked = true - s.sackedOut += s.pCount(seg, s.maxPayloadSize) + s.SackedOut += s.pCount(seg, s.MaxPayloadSize) } seg = seg.Next() } @@ -1322,18 +1234,18 @@ func checkDSACK(rcvdSeg *segment) bool { // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. 
- if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { - s.updateRTO(time.Now().Sub(s.rttMeasureTime)) - s.rttMeasureSeqNum = s.sndNxt + if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { + s.updateRTO(time.Now().Sub(s.RTTMeasureTime)) + s.RTTMeasureSeqNum = s.SndNxt } // Update Timestamp if required. See RFC7323, section-4.3. - if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS { - s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber) + if s.ep.SendTSOk && rcvdSeg.parsedOptions.TS { + s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.MaxSentAck, rcvdSeg.sequenceNumber) } // Insert SACKBlock information into our scoreboard. - if s.ep.sackPermitted { + if s.ep.SACKPermitted { for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds // true: @@ -1347,7 +1259,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // NOTE: This check specifically excludes DSACK blocks // which have start/end before sndUna and are used to // indicate spurious retransmissions. - if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + if rcvdSeg.ackNumber.LessThan(sb.Start) && s.SndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.SndNxt) && !s.ep.scoreboard.IsSACKED(sb) { s.ep.scoreboard.Insert(sb) rcvdSeg.hasNewSACKInfo = true } @@ -1375,10 +1287,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ack := rcvdSeg.ackNumber fastRetransmit := false // Do not leave fast recovery, if the ACK is out of range. - if s.fr.active { + if s.FastRecovery.Active { // Leave fast recovery if it acknowledges all the data covered by // this fast recovery session. 
- if (ack-1).InRange(s.sndUna, s.sndNxt) && s.fr.last.LessThan(ack) { + if (ack-1).InRange(s.SndUna, s.SndNxt) && s.FastRecovery.Last.LessThan(ack) { s.leaveRecovery() } } else { @@ -1392,28 +1304,28 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Stash away the current window size. - s.sndWnd = rcvdSeg.window + s.SndWnd = rcvdSeg.window // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously // unacknowledged segment. if s.zeroWindowProbing && rcvdSeg.window > 0 && - (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { + (ack == s.SndUna || (ack-1).InRange(s.SndUna, s.SndNxt)) { s.disableZeroWindowProbing() } // On receiving the ACK for the zero window probe, account for it and // skip trying to send any segment as we are still probing for // receive window to become non-zero. - if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna { + if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.SndUna { s.unackZeroWindowProbes-- return } // Ignore ack if it doesn't acknowledge any new data. - if (ack - 1).InRange(s.sndUna, s.sndNxt) { - s.dupAckCount = 0 + if (ack - 1).InRange(s.SndUna, s.SndNxt) { + s.DupAckCount = 0 // See : https://tools.ietf.org/html/rfc1323#section-3.3. // Specifically we should only update the RTO using TSEcr if the @@ -1423,7 +1335,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // averaged RTT measurement only if the segment acknowledges // some new data, i.e., only if it advances the left edge of // the send window. - if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { + if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { // TSVal/Ecr values sent by Netstack are at a millisecond // granularity. 
elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond @@ -1438,12 +1350,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // When an ack is received we must rearm the timer. // RFC 6298 5.3 s.probeTimer.disable() - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } // Remove all acknowledged data from the write list. - acked := s.sndUna.Size(ack) - s.sndUna = ack + acked := s.SndUna.Size(ack) + s.SndUna = ack // The remote ACK-ing at least 1 byte is an indication that we have a // full-duplex connection to the remote as the only way we will receive an @@ -1457,7 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } ackLeft := acked - originalOutstanding := s.outstanding + originalOutstanding := s.Outstanding for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1466,10 +1378,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg, s.maxPayloadSize) + prevCount := s.pCount(seg, s.MaxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) + s.Outstanding -= prevCount - s.pCount(seg, s.MaxPayloadSize) break } @@ -1478,7 +1390,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Update the RACK fields if SACK is enabled. - if s.ep.sackPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.ep.SACKPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1488,10 +1400,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). 
- if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg, s.maxPayloadSize) + if !s.ep.SACKPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + s.Outstanding -= s.pCount(seg, s.MaxPayloadSize) } else { - s.sackedOut -= s.pCount(seg, s.maxPayloadSize) + s.SackedOut -= s.pCount(seg, s.MaxPayloadSize) } seg.decRef() ackLeft -= datalen @@ -1501,13 +1413,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.ep.updateSndBufferUsage(int(acked)) // Clear SACK information for all acked data. - s.ep.scoreboard.Delete(s.sndUna) + s.ep.scoreboard.Delete(s.SndUna) // If we are not in fast recovery then update the congestion // window based on the number of acknowledged packets. - if !s.fr.active { - s.cc.Update(originalOutstanding - s.outstanding) - if s.fr.last.LessThan(s.sndUna) { + if !s.FastRecovery.Active { + s.cc.Update(originalOutstanding - s.Outstanding) + if s.FastRecovery.Last.LessThan(s.SndUna) { s.state = tcpip.Open // Update RACK when we are exiting fast or RTO // recovery as described in the RFC @@ -1522,16 +1434,16 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // It is possible for s.outstanding to drop below zero if we get // a retransmit timeout, reset outstanding to zero but later // get an ack that cover previously sent data. - if s.outstanding < 0 { - s.outstanding = 0 + if s.Outstanding < 0 { + s.Outstanding = 0 } s.SetPipe() // If all outstanding data was acknowledged the disable the timer. // RFC 6298 Rule 5.3 - if s.sndUna == s.sndNxt { - s.outstanding = 0 + if s.SndUna == s.SndNxt { + s.Outstanding = 0 // Reset firstRetransmittedSegXmitTime to the zero value. 
s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() @@ -1539,7 +1451,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } - if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { // Update RACK reorder window. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: @@ -1549,7 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the // reorder window. - if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.fr.active { + if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.FastRecovery.Active { // If any segment is marked as lost by // RACK, enter recovery and retransmit // the lost segments. @@ -1558,19 +1470,19 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { fastRetransmit = true } - if s.fr.active { + if s.FastRecovery.Active { s.rc.DoRecovery(nil, fastRetransmit) } } // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. - if s.fr.active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { + if s.FastRecovery.Active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { s.lr.DoRecovery(rcvdSeg, fastRetransmit) // When SACK is enabled data sending is governed by steps in // RFC 6675 Section 5 recovery steps A-C. // See: https://tools.ietf.org/html/rfc6675#section-5. 
- if s.ep.sackPermitted { + if s.ep.SACKPermitted { return } } @@ -1587,7 +1499,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() - if s.sndCwnd < s.sndSsthresh { + if s.SndCwnd < s.Ssthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } } @@ -1601,11 +1513,11 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // then use the conservative timer described in RFC6675 Section 6.0, // otherwise follow the standard time described in RFC6298 Section 5.1. if err != nil && seg.data.Size() != 0 { - if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted { - s.resendTimer.enable(s.rto) + if s.FastRecovery.Active && seg.xmitCount > 1 && s.ep.SACKPermitted { + s.resendTimer.enable(s.RTO) } else { if !s.resendTimer.enabled() { - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } } } @@ -1616,15 +1528,15 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { - s.lastSendTime = time.Now() - if seq == s.rttMeasureSeqNum { - s.rttMeasureTime = s.lastSendTime + s.LastSendTime = time.Now() + if seq == s.RTTMeasureSeqNum { + s.RTTMeasureTime = s.LastSendTime } rcvNxt, rcvWnd := s.ep.rcv.getSendParams() // Remember the max sent ack. - s.maxSentAck = rcvNxt + s.MaxSentAck = rcvNxt return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index ba41cff6d..2f805d8ce 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -24,26 +24,6 @@ type unixTime struct { nano int64 } -// saveLastSendTime is invoked by stateify. 
-func (s *sender) saveLastSendTime() unixTime { - return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (s *sender) loadLastSendTime(unix unixTime) { - s.lastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRttMeasureTime is invoked by stateify. -func (s *sender) saveRttMeasureTime() unixTime { - return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()} -} - -// loadRttMeasureTime is invoked by stateify. -func (s *sender) loadRttMeasureTime(unix unixTime) { - s.rttMeasureTime = time.Unix(unix.second, unix.nano) -} - // afterLoad is invoked by stateify. func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 5cdd5b588..c58361bc1 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -33,6 +33,7 @@ const ( tsOptionSize = 12 maxTCPOptionSize = 40 mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload + latency = 5 * time.Millisecond ) func setStackRACKPermitted(t *testing.T, c *context.Context) { @@ -182,6 +183,9 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en for i := 0; i < numPackets; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload + // This delay is added to increase RTT as low RTT can cause TLP + // before sending ACK. + time.Sleep(latency) } return data @@ -479,7 +483,7 @@ func TestRACKOnePacketTailLoss(t *testing.T) { }{ // #3 was retransmitted as TLP. {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, // RTO should not have fired. 
{tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, @@ -852,8 +856,8 @@ func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan return } - if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { - probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) + if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT { + probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT) return } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 81f800cad..20c9761f2 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -160,12 +160,9 @@ func TestSackPermittedAccept(t *testing.T) { defer c.Cleanup() if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -235,12 +232,9 @@ func TestSackDisabledAccept(t *testing.T) { defer c.Cleanup() if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. 
- var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9c23469f2..9916182e3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" + tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" "gvisor.dev/gvisor/pkg/test/testutil" @@ -86,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for receive to be notified. select { case <-notifyRead: @@ -129,8 +130,8 @@ func TestGiveUpConnect(t *testing.T) { { err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -144,8 +145,8 @@ func TestGiveUpConnect(t *testing.T) { // and stats updates. { err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAborted); !ok { - t.Fatalf("got ep.Connect(...) 
= %v, want = %s", err, &tcpip.ErrAborted{}) + if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -158,6 +159,76 @@ func TestGiveUpConnect(t *testing.T) { } } +// Test for ICMP error handling without completing handshake. +func TestConnectICMPError(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventHUp) + defer wq.EventUnregister(&waitEntry) + + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + } + } + + syn := c.GetPacket() + checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) + + wep := ep.(interface { + StopWork() + ResumeWork() + LastErrorLocked() tcpip.Error + }) + + // Stop the protocol loop, ensure that the ICMP error is processed and + // the last ICMP error is read before the loop is resumed. This sanity + // tests the handshake completion logic on ICMP errors. + wep.StopWork() + + c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU) + + for { + if err := wep.LastErrorLocked(); err != nil { + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) + } + break + } + time.Sleep(time.Millisecond) + } + + wep.ResumeWork() + + <-notifyCh + + // The stack would have unregistered the endpoint because of the ICMP error. + // Expect a RST for any subsequent packets sent to the endpoint. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, + AckNum: c.IRS + 1, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + func TestConnectIncrementActiveConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -201,8 +272,8 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Errorf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -392,7 +463,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -929,17 +1000,14 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { } // Get expected window size. 
- rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) - } + rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} { err := c.EP.Connect(connectAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Connect(%+v): %s", connectAddr, err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) } } @@ -955,11 +1023,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { // when completing the handshake for a new TCP connection from a TCP // listening socket. It should be present in the sent TCP SYN-ACK segment. func TestUserSuppliedMSSOnListenAccept(t *testing.T) { - const ( - nonSynCookieAccepts = 2 - totalAccepts = 4 - mtu = 5000 - ) + const mtu = 5000 ips := []struct { name string @@ -1033,12 +1097,6 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { ip.createEP(c) - // Set the SynRcvd threshold to force a syn cookie based accept to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) } @@ -1048,13 +1106,17 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { t.Fatalf("Bind(%+v): %s:", bindAddr, err) } - if err := c.EP.Listen(totalAccepts); err != nil { - t.Fatalf("Listen(%d): %s:", totalAccepts, err) + backlog := 5 + // Keep the number of client requests twice to the backlog + // such that half of the connections do not use syncookies + // and the other half does. 
+ clientConnects := backlog * 2 + + if err := c.EP.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s:", backlog, err) } - // The first nonSynCookieAccepts packets sent will trigger a gorooutine - // based accept. The rest will trigger a cookie based accept. - for i := 0; i < totalAccepts; i++ { + for i := 0; i < clientConnects; i++ { // Send a SYN requests. iss := seqnum.Value(i) srcPort := context.TestPort + uint16(i) @@ -1297,6 +1359,98 @@ func TestListenShutdown(t *testing.T) { )) } +var _ waiter.EntryCallback = (callback)(nil) + +type callback func(*waiter.Entry, waiter.EventMask) + +func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) { + cb(entry, mask) +} + +func TestListenerReadinessOnEvent(t *testing.T) { + s := stack.New(stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + { + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } + const id = 1 + if err := s.CreateNIC(id, ep); err != nil { + t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) + } + if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { + t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: header.IPv4EmptySubnet, NIC: id}, + }) + } + + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { + t.Fatalf("Bind(%s): %s", context.StackAddr, err) + } + const backlog = 1 + if err := ep.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s", backlog, err) + } + + address, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress(): %s", err) + } + + conn, err := 
s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer conn.Close() + + events := make(chan waiter.EventMask) + // Scope `entry` to allow a binding of the same name below. + { + entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) { + events <- ep.Readiness(mask) + })} + wq.EventRegister(&entry, waiter.EventIn) + defer wq.EventUnregister(&entry) + } + + entry, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&entry, waiter.EventOut) + defer wq.EventUnregister(&entry) + + switch err := conn.Connect(address).(type) { + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect(%#v): %v", address, err) + } + + // Read at least one event. + got := <-events + for { + select { + case event := <-events: + got |= event + continue + case <-ch: + if want := waiter.ReadableEvents; got != want { + t.Errorf("observed events = %b, want %b", got, want) + } + } + break + } +} + // TestListenCloseWhileConnect tests for the listening endpoint to // drain the accept-queue when closed. This should reset all of the // pending connections that are waiting to be accepted. @@ -1459,8 +1613,8 @@ func TestConnectBindToDevice(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. 
@@ -1520,8 +1674,8 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -1993,9 +2147,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Bump up the receive buffer size such that, when the receive window grows, // the scaled window exceeds maxUint16. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true) // Keep the payload size < segment overhead and such that it is a multiple // of the window scaled value. This enables the test to perform equality @@ -2115,9 +2267,7 @@ func TestNoWindowShrinking(t *testing.T) { initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) // Now shrink the receive buffer to half its original size. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true) data := generateRandomPayload(t, rcvBufSize) // Send a payload of half the size of rcvBufSize. @@ -2373,9 +2523,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. 
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } + ep.SocketOptions().SetReceiveBufferSize(65535*3, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -2395,7 +2543,7 @@ func TestScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -2447,9 +2595,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } + ep.SocketOptions().SetReceiveBufferSize(65535*3, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -2469,7 +2615,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3001,8 +3147,8 @@ func TestSetTTL(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -3042,9 +3188,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. 
const rcvBufferSize = 0x20000 - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } + ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -3063,7 +3207,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3087,11 +3231,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, mtu) defer c.Cleanup() - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(0) + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -3119,7 +3261,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3185,9 +3327,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. 
const rcvBufferSize = 0x20000 const wndScale = 3 - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } + c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) // Start connection attempt. we, ch := waiter.NewChannelEntry(nil) @@ -3196,8 +3336,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -3315,8 +3455,8 @@ loop: case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) } break loop case <-time.After(1 * time.Second): @@ -3366,8 +3506,8 @@ func TestSendOnResetConnection(t *testing.T) { var r bytes.Reader r.Reset(make([]byte, 10)) _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Write(...) 
mismatch (-want +got):\n%s", d) } } @@ -4320,8 +4460,8 @@ func TestReadAfterClosedState(t *testing.T) { var buf bytes.Buffer { _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) - if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) + if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { + t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) } } } @@ -4365,8 +4505,8 @@ func TestReusePort(t *testing.T) { } { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } c.EP.Close() @@ -4411,11 +4551,7 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) - } - + s := ep.SocketOptions().GetReceiveBufferSize() if int(s) != v { t.Fatalf("got receive buffer size = %d, want = %d", s, v) } @@ -4521,10 +4657,7 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min/2. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) - } - + ep.SocketOptions().SetReceiveBufferSize(99, true) checkRecvBufferSize(t, ep, 200) ep.SocketOptions().SetSendBufferSize(149, true) @@ -4532,15 +4665,11 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, 300) // Set values above the max. 
- if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - + ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true) // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) - // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) } @@ -4665,8 +4794,8 @@ func TestSelfConnect(t *testing.T) { { err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -4814,7 +4943,13 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("unknown address type: '%s'", candidateAddressType) } - start, end := s.PortRange() + const ( + start = 16000 + end = 16050 + ) + if err := s.SetPortRange(start, end); err != nil { + t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) + } for i := start; i <= end; i++ { if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { t.Fatalf("Bind(%d) failed: %s", i, err) @@ -5387,7 +5522,7 @@ func TestListenBacklogFull(t *testing.T) { for i := 0; i < listenBacklog; i++ { _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5404,7 +5539,7 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that there are no more connections that can be accepted. 
_, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5416,7 +5551,7 @@ func TestListenBacklogFull(t *testing.T) { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5445,8 +5580,8 @@ func TestListenBacklogFull(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a // non unicast IPv4 address are not accepted. func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpip.Address("\xe0\x00\x01\x02") - otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + multicastAddr := tcpiptestutil.MustParse4("224.0.1.2") + otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3") subnet := context.StackAddrWithPrefix.Subnet() subnetBroadcastAddr := subnet.Broadcast() @@ -5557,8 +5692,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + multicastAddr := tcpiptestutil.MustParse6("ff0e::101") + otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102") tests := []struct { name string @@ -5671,15 +5806,13 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } // Test acceptance. - // Start listening. 
- listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { + if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %s", err) } // Send two SYN's the first one should get a SYN-ACK, the // second one should not get any response and is dropped as - // the synRcvd count will be equal to backlog. + // the accept queue is full. irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5701,23 +5834,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - // - // NOTE: we did not complete the handshake for the previous one so the - // accept backlog should be empty and there should be one connection in - // synRcvd state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Now complete the previous connection and verify that there is a connection - // to accept. + // Now complete the previous connection. // Send ACK. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5728,13 +5845,26 @@ func TestListenSynRcvdQueueFull(t *testing.T) { RcvWnd: 30000, }) - // Try to accept the connections in the backlog. + // Verify if that is delivered to the accept queue. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.ReadableEvents) defer c.WQ.EventUnregister(&we) + <-ch + + // Now execute send one more SYN. The stack should not respond as the backlog + // is full at this point. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + 1, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: seqnum.Value(889), + RcvWnd: 30000, + }) + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) + // Try to accept the connections in the backlog. newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5764,11 +5894,6 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.TCPSynRcvdCountThresholdOption(1) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - // Create TCP endpoint. var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -5781,9 +5906,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { t.Fatalf("Bind failed: %s", err) } - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { + // Test for SynCookies usage after filling up the backlog. + if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %s", err) } @@ -5811,7 +5935,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { defer c.WQ.EventUnregister(&we) _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5827,7 +5951,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Now verify that there are no more connections that can be accepted. 
_, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5966,7 +6090,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { t.Fatalf("Accept failed: %s", err) } - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Try to accept the connections in the backlog. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.ReadableEvents) @@ -6034,7 +6158,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { // Verify that there is only one acceptable connection at this point. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6104,7 +6228,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6156,7 +6280,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { defer wq.EventUnregister(&we) aep, _, err := ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6174,8 +6298,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } { err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) + if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { + t.Errorf("Connect(...) 
mismatch (-want +got):\n%s", d) } } // Listening endpoint remains in listen state. @@ -6295,7 +6419,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // window increases to the full available buffer size. for { _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } } @@ -6426,7 +6550,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalCopied := 0 for { res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } totalCopied += res.Count @@ -6618,7 +6742,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6737,7 +6861,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6844,7 +6968,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6934,7 +7058,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Try to accept the connection. c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. 
select { case <-ch: @@ -7008,7 +7132,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7158,7 +7282,7 @@ func TestTCPCloseWithData(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7553,8 +7677,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS - c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - + c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( @@ -7590,8 +7713,8 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) } // Send data. This should result in an acceptable endpoint. @@ -7649,8 +7772,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) } // Sleep for a little of the tcpDeferAccept timeout. 
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 2949588ce..1deb1fe4d 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -139,9 +139,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -202,9 +202,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index e73f90bb0..53efecc5a 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -331,8 +331,8 @@ func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) + if p.Pkt.GSOOptions.Type != stack.GSONone && p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize { + c.t.Errorf("got L3HdrLen = %d, want = %d", 
p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize) } checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) @@ -757,9 +757,7 @@ func (c *Context) Create(epRcvBuf int) { } if epRcvBuf != -1 { - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */) } } @@ -1216,9 +1214,9 @@ func (c *Context) SACKEnabled() bool { // SetGSOEnabled enables or disables generic segmentation offload. func (c *Context) SetGSOEnabled(enable bool) { if enable { - c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO + c.linkEP.SupportedGSOKind = stack.HWGSOSupported } else { - c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO + c.linkEP.SupportedGSOKind = stack.GSONotSupported } } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 153e8c950..dd5c910ae 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -56,6 +56,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 956da0e0c..f7dd50d35 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "fmt" "io" "sync/atomic" @@ -89,12 +88,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList udpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList udpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. 
mu sync.RWMutex `state:"nosave"` @@ -144,6 +142,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // +stateify savable @@ -173,14 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - rcvBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -188,9 +190,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } return e @@ -622,26 +624,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { e.mu.Lock() e.sendTOS = uint8(v) e.mu.Unlock() - - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. 
- var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) - } - - if v < rs.Min { - v = rs.Min - } - if v > rs.Max { - v = rs.Max - } - - e.mu.Lock() - e.rcvBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -802,12 +784,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.mu.Lock() v := int(e.ttl) @@ -872,7 +848,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: ProtocolNumber, TTL: ttl, TOS: tos, @@ -1255,20 +1231,29 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // verifyChecksum verifies the checksum unless RX checksum offload is enabled. -// On IPv4, UDP checksum is optional, and a zero value means the transmitter -// omitted the checksum generation (RFC768). -// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). 
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { - if !pkt.RXTransportChecksumValidated && - (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { - netHdr := pkt.Network() - xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) - for _, v := range pkt.Data().Views() { - xsum = header.Checksum(v, xsum) - } - return hdr.CalculateChecksum(xsum) == 0xffff + if pkt.RXTransportChecksumValidated { + return true + } + + // On IPv4, UDP checksum is optional, and a zero value means the transmitter + // omitted the checksum generation, as per RFC 768: + // + // An all zero transmitted checksum value means that the transmitter + // generated no checksum (for debugging or for higher level protocols that + // don't care). + // + // On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1: + // + // Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP + // checksum is not optional. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 { + return true } - return true + + netHdr := pkt.Network() + payloadChecksum := pkt.Data().AsRange().Checksum() + return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum) } // HandlePacket is called by the stack when new packets arrive to this transport @@ -1284,7 +1269,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } if !verifyChecksum(hdr, pkt) { - // Checksum Error. 
e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() return @@ -1302,7 +1286,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -1436,3 +1421,18 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 21a6aa460..4aba68b21 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,43 +37,25 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. 
- e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() + e.mu.Lock() defer e.mu.Unlock() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 77ca70a04..dc2e3f493 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -34,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -2364,7 +2365,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } ipv4Subnet := ipv4Addr.Subnet() ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4Gateway := testutil.MustParse4("192.168.1.1") ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ Address: "\xc0\xa8\x01\x3a", PrefixLen: 31, diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD index 7f983a0b3..366f068e3 100644 --- a/pkg/test/dockerutil/BUILD +++ b/pkg/test/dockerutil/BUILD @@ -36,8 +36,8 @@ go_test( tags = [ # Requires docker and runsc to be configured before test runs. # Also requires the test to be run as root. 
- "manual", "local", + "manual", ], visibility = ["//:sandbox"], ) diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 41fcf4978..06152a444 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -434,7 +434,14 @@ func (c *Container) Wait(ctx context.Context) error { select { case err := <-errChan: return err - case <-statusChan: + case res := <-statusChan: + if res.StatusCode != 0 { + var msg string + if res.Error != nil { + msg = res.Error.Message + } + return fmt.Errorf("container returned non-zero status: %d, msg: %q", res.StatusCode, msg) + } return nil } } diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go index 4855a52fc..12fe98b16 100644 --- a/pkg/test/dockerutil/profile.go +++ b/pkg/test/dockerutil/profile.go @@ -82,10 +82,15 @@ func (p *profile) createProcess(c *Container) error { } // The root directory of this container's runtime. - root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) + rootDir := fmt.Sprintf("/var/run/docker/runtime-%s/moby", c.runtime) + if _, err := os.Stat(rootDir); os.IsNotExist(err) { + // In docker v20+, due to https://github.com/moby/moby/issues/42345 the + // rootDir seems to always be the following. + rootDir = "/var/run/docker/runtime-runc/moby" + } - // Format is `runsc --root=rootdir debug --profile-*=file --duration=24h containerID`. - args := []string{root, "debug"} + // Format is `runsc --root=rootDir debug --profile-*=file --duration=24h containerID`. 
+ args := []string{fmt.Sprintf("--root=%s", rootDir), "debug"} for _, profileArg := range p.Types { outputPath := filepath.Join(p.BasePath, fmt.Sprintf("%s.pprof", profileArg)) args = append(args, fmt.Sprintf("--profile-%s=%s", profileArg, outputPath)) diff --git a/runsc/BUILD b/runsc/BUILD index 3b91b984a..7a7dcc8d5 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -9,6 +9,7 @@ go_binary( "version.go", ], pure = True, + tags = ["staging"], visibility = [ "//visibility:public", ], @@ -49,5 +50,4 @@ sh_test( srcs = ["version_test.sh"], args = ["$(location :runsc)"], data = [":runsc"], - tags = ["noguitar"], ) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 67307ab3c..d51347fe1 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/cleanup", "//pkg/context", "//pkg/control/server", + "//pkg/coverage", "//pkg/cpuid", "//pkg/eventchannel", "//pkg/fd", @@ -37,6 +38,7 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/memutil", + "//pkg/metric", "//pkg/rand", "//pkg/refs", "//pkg/refsvfs2", @@ -57,6 +59,7 @@ go_library( "//pkg/sentry/fs/tmpfs", "//pkg/sentry/fs/tty", "//pkg/sentry/fs/user", + "//pkg/sentry/fsimpl/cgroupfs", "//pkg/sentry/fsimpl/devpts", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/fuse", @@ -66,6 +69,7 @@ go_library( "//pkg/sentry/fsimpl/proc", "//pkg/sentry/fsimpl/sys", "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/fsimpl/verity", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel:uncaught_signal_go_proto", diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 1ae76d7d7..9b270cbf2 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -400,9 +400,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Set up the restore environment. 
ctx := k.SupervisorContext() - mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled) + mntr := newContainerMounter(&cm.l.root, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled) if kernel.VFS2Enabled { - ctx, err = mntr.configureRestore(ctx, cm.l.root.conf) + ctx, err = mntr.configureRestore(ctx) if err != nil { return fmt.Errorf("configuring filesystem restore: %v", err) } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 32adde643..bf4a41f77 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/gofer" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/fs/user" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/cgroupfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" gofervfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer" @@ -103,17 +104,22 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name // compileMounts returns the supported mounts from the mount spec, adding any // mandatory mounts that are required by the OCI specification. -func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount { +func compileMounts(spec *specs.Spec, conf *config.Config, vfs2Enabled bool) []specs.Mount { // Keep track of whether proc and sys were mounted. var procMounted, sysMounted, devMounted, devptsMounted bool var mounts []specs.Mount // Mount all submounts from the spec. for _, m := range spec.Mounts { - if !vfs2Enabled && !specutils.IsVFS1SupportedDevMount(m) { + if !specutils.IsSupportedDevMount(m, vfs2Enabled) { log.Warningf("ignoring dev mount at %q", m.Destination) continue } + // Unconditionally drop any cgroupfs mounts. If requested, we'll add our + // own below. 
+ if m.Type == cgroupfs.Name { + continue + } switch filepath.Clean(m.Destination) { case "/proc": procMounted = true @@ -132,6 +138,24 @@ func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount { // Mount proc and sys even if the user did not ask for it, as the spec // says we SHOULD. var mandatoryMounts []specs.Mount + + if conf.Cgroupfs { + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: tmpfsvfs2.Name, + Destination: "/sys/fs/cgroup", + }) + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: cgroupfs.Name, + Destination: "/sys/fs/cgroup/memory", + Options: []string{"memory"}, + }) + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: cgroupfs.Name, + Destination: "/sys/fs/cgroup/cpu", + Options: []string{"cpu"}, + }) + } + if !procMounted { mandatoryMounts = append(mandatoryMounts, specs.Mount{ Type: procvfs2.Name, @@ -208,7 +232,7 @@ func parseMountOption(opt string, allowedKeys ...string) (bool, error) { // mountDevice returns a device string based on the fs type and target // of the mount. -func mountDevice(m specs.Mount) string { +func mountDevice(m *specs.Mount) string { if m.Type == bind { // Make a device string that includes the target, which is consistent across // S/R and uniquely identifies the connection. @@ -232,6 +256,8 @@ func mountFlags(opts []string) fs.MountSourceFlags { mf.NoAtime = true case "noexec": mf.NoExec = true + case "bind", "rbind": + // These are the same as a mount with type="bind". default: log.Warningf("ignoring unknown mount option %q", o) } @@ -248,6 +274,10 @@ func isSupportedMountFlag(fstype, opt string) bool { ok, err := parseMountOption(opt, tmpfsAllowedData...) return ok && err == nil } + if fstype == cgroupfs.Name { + ok, err := parseMountOption(opt, cgroupfs.SupportedMountOptions...) + return ok && err == nil + } return false } @@ -458,9 +488,9 @@ func (m *mountHint) isSupported() bool { // For now enforce that all options are the same. 
Once bind mount is properly // supported, then we should ensure the master is less restrictive than the // container, e.g. master can be 'rw' while container mounts as 'ro'. -func (m *mountHint) checkCompatible(mount specs.Mount) error { +func (m *mountHint) checkCompatible(mount *specs.Mount) error { // Remove options that don't affect to mount's behavior. - masterOpts := filterUnsupportedOptions(m.mount) + masterOpts := filterUnsupportedOptions(&m.mount) replicaOpts := filterUnsupportedOptions(mount) if len(masterOpts) != len(replicaOpts) { @@ -484,7 +514,7 @@ func (m *mountHint) fileAccessType() config.FileAccessType { return config.FileAccessShared } -func filterUnsupportedOptions(mount specs.Mount) []string { +func filterUnsupportedOptions(mount *specs.Mount) []string { rv := make([]string, 0, len(mount.Options)) for _, o := range mount.Options { if isSupportedMountFlag(mount.Type, o) { @@ -548,7 +578,7 @@ func newPodMountHints(spec *specs.Spec) (*podMountHints, error) { return &podMountHints{mounts: mnts}, nil } -func (p *podMountHints) findMount(mount specs.Mount) *mountHint { +func (p *podMountHints) findMount(mount *specs.Mount) *mountHint { for _, m := range p.mounts { if m.mount.Source == mount.Source { return m @@ -572,11 +602,11 @@ type containerMounter struct { hints *podMountHints } -func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter { +func newContainerMounter(info *containerInfo, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter { return &containerMounter{ - root: spec.Root, - mounts: compileMounts(spec, vfs2Enabled), - fds: fdDispenser{fds: goferFDs}, + root: info.spec.Root, + mounts: compileMounts(info.spec, info.conf, vfs2Enabled), + fds: fdDispenser{fds: info.goferFDs}, k: k, hints: hints, } @@ -651,7 +681,8 @@ func (c *containerMounter) mountSubmounts(ctx context.Context, conf *config.Conf root := mns.Root() defer root.DecRef(ctx) - 
for _, m := range c.mounts { + for i := range c.mounts { + m := &c.mounts[i] log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options) if hint := c.hints.findMount(m); hint != nil && hint.isSupported() { if err := c.mountSharedSubmount(ctx, mns, root, m, hint); err != nil { @@ -686,7 +717,7 @@ func (c *containerMounter) checkDispenser() error { func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *config.Config, hint *mountHint) (*fs.Inode, error) { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. - fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, hint.mount) + fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, &hint.mount) if err != nil { return nil, err } @@ -706,7 +737,7 @@ func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *config.C mf.ReadOnly = true } - inode, err := filesystem.Mount(ctx, mountDevice(hint.mount), mf, strings.Join(opts, ","), nil) + inode, err := filesystem.Mount(ctx, mountDevice(&hint.mount), mf, strings.Join(opts, ","), nil) if err != nil { return nil, fmt.Errorf("creating mount %q: %v", hint.name, err) } @@ -768,13 +799,14 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *config.Con // getMountNameAndOptions retrieves the fsName, opts, and useOverlay values // used for mounts. 
-func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.Mount) (string, []string, bool, error) { +func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m *specs.Mount) (string, []string, bool, error) { + specutils.MaybeConvertToBindMount(m) + var ( fsName string opts []string useOverlay bool ) - switch m.Type { case devpts.Name, devtmpfs.Name, procvfs2.Name, sysvfs2.Name: fsName = m.Type @@ -795,14 +827,20 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M opts = p9MountData(fd, c.getMountAccessType(conf, m), conf.VFS2) // If configured, add overlay to all writable mounts. useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly - + case cgroupfs.Name: + fsName = m.Type + var err error + opts, err = parseAndFilterOptions(m.Options, cgroupfs.SupportedMountOptions...) + if err != nil { + return "", nil, false, err + } default: log.Warningf("ignoring unknown filesystem type %q", m.Type) } return fsName, opts, useOverlay, nil } -func (c *containerMounter) getMountAccessType(conf *config.Config, mount specs.Mount) config.FileAccessType { +func (c *containerMounter) getMountAccessType(conf *config.Config, mount *specs.Mount) config.FileAccessType { if hint := c.hints.findMount(mount); hint != nil { return hint.fileAccessType() } @@ -813,7 +851,7 @@ func (c *containerMounter) getMountAccessType(conf *config.Config, mount specs.M // be readonly, a lower ramfs overlay is added to create the mount point dir. // Another overlay is added with tmpfs on top if Config.Overlay is true. // 'm.Destination' must be an absolute path with '..' and symlinks resolved. 
-func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Config, mns *fs.MountNamespace, root *fs.Dirent, m specs.Mount) error { +func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Config, mns *fs.MountNamespace, root *fs.Dirent, m *specs.Mount) error { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m) @@ -887,7 +925,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Confi // mountSharedSubmount binds mount to a previously mounted volume that is shared // among containers in the same pod. -func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.MountNamespace, root *fs.Dirent, mount specs.Mount, source *mountHint) error { +func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.MountNamespace, root *fs.Dirent, mount *specs.Mount, source *mountHint) error { if err := source.checkCompatible(mount); err != nil { return err } @@ -912,7 +950,7 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun // addRestoreMount adds a mount to the MountSources map used for restoring a // checkpointed container. -func (c *containerMounter) addRestoreMount(conf *config.Config, renv *fs.RestoreEnvironment, m specs.Mount) error { +func (c *containerMounter) addRestoreMount(conf *config.Config, renv *fs.RestoreEnvironment, m *specs.Mount) error { fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m) if err != nil { return err @@ -960,7 +998,8 @@ func (c *containerMounter) createRestoreEnvironment(conf *config.Config) (*fs.Re // Add submounts. 
var tmpMounted bool - for _, m := range c.mounts { + for i := range c.mounts { + m := &c.mounts[i] if err := c.addRestoreMount(conf, renv, m); err != nil { return nil, err } @@ -975,7 +1014,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *config.Config) (*fs.Re Type: tmpfsvfs2.Name, Destination: "/tmp", } - if err := c.addRestoreMount(conf, renv, tmpMount); err != nil { + if err := c.addRestoreMount(conf, renv, &tmpMount); err != nil { return nil, err } } @@ -1034,7 +1073,7 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *config.Config, mn // another user. This is normally done for /tmp. Options: []string{"mode=01777"}, } - return c.mountSubmount(ctx, conf, mns, root, tmpMount) + return c.mountSubmount(ctx, conf, mns, root, &tmpMount) default: return err diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go index b4f12d034..09ffda628 100644 --- a/runsc/boot/fs_test.go +++ b/runsc/boot/fs_test.go @@ -244,7 +244,7 @@ func TestGetMountAccessType(t *testing.T) { } mounter := containerMounter{hints: podHints} conf := &config.Config{FileAccessMounts: config.FileAccessShared} - if got := mounter.getMountAccessType(conf, specs.Mount{Source: source}); got != tst.want { + if got := mounter.getMountAccessType(conf, &specs.Mount{Source: source}); got != tst.want { t.Errorf("getMountAccessType(), want: %v, got: %v", tst.want, got) } }) diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 774621970..10f2d3d35 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -29,10 +29,12 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/coverage" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/refsvfs2" @@ -216,6 +218,8 @@ func New(args Args) (*Loader, error) { return nil, 
fmt.Errorf("setting up memory usage: %v", err) } + metric.CreateSentryMetrics() + // Is this a VFSv2 kernel? if args.Conf.VFS2 { kernel.VFS2Enabled = true @@ -226,6 +230,33 @@ func New(args Args) (*Loader, error) { vfs2.Override() } + // Make host FDs stable between invocations. Host FDs must map to the exact + // same number when the sandbox is restored. Otherwise the wrong FD will be + // used. + info := containerInfo{} + newfd := startingStdioFD + + for _, stdioFD := range args.StdioFDs { + // Check that newfd is unused to avoid clobbering over it. + if _, err := unix.FcntlInt(uintptr(newfd), unix.F_GETFD, 0); !errors.Is(err, unix.EBADF) { + if err != nil { + return nil, fmt.Errorf("error checking for FD (%d) conflict: %w", newfd, err) + } + return nil, fmt.Errorf("unable to remap stdios, FD %d is already in use", newfd) + } + + err := unix.Dup3(stdioFD, newfd, unix.O_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("dup3 of stdios failed: %w", err) + } + info.stdioFDs = append(info.stdioFDs, fd.New(newfd)) + _ = unix.Close(stdioFD) + newfd++ + } + for _, goferFD := range args.GoferFDs { + info.goferFDs = append(info.goferFDs, fd.New(goferFD)) + } + // Create kernel and platform. p, err := createPlatform(args.Conf, args.Device) if err != nil { @@ -345,6 +376,7 @@ func New(args Args) (*Loader, error) { if err != nil { return nil, fmt.Errorf("creating init process for root container: %v", err) } + info.procArgs = procArgs if err := initCompatLogs(args.UserLogFD); err != nil { return nil, fmt.Errorf("initializing compat logs: %v", err) @@ -355,6 +387,9 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("creating pod mount hints: %v", err) } + info.conf = args.Conf + info.spec = args.Spec + if kernel.VFS2Enabled { // Set up host mount that will be used for imported fds. 
hostFilesystem, err := hostvfs2.NewFilesystem(k.VFS()) @@ -369,37 +404,6 @@ func New(args Args) (*Loader, error) { k.SetHostMount(hostMount) } - info := containerInfo{ - conf: args.Conf, - spec: args.Spec, - procArgs: procArgs, - } - - // Make host FDs stable between invocations. Host FDs must map to the exact - // same number when the sandbox is restored. Otherwise the wrong FD will be - // used. - newfd := startingStdioFD - for _, stdioFD := range args.StdioFDs { - // Check that newfd is unused to avoid clobbering over it. - if _, err := unix.FcntlInt(uintptr(newfd), unix.F_GETFD, 0); !errors.Is(err, unix.EBADF) { - if err != nil { - return nil, fmt.Errorf("error checking for FD (%d) conflict: %w", newfd, err) - } - return nil, fmt.Errorf("unable to remap stdios, FD %d is already in use", newfd) - } - - err := unix.Dup3(stdioFD, newfd, unix.O_CLOEXEC) - if err != nil { - return nil, fmt.Errorf("dup3 of stdios failed: %w", err) - } - info.stdioFDs = append(info.stdioFDs, fd.New(newfd)) - _ = unix.Close(stdioFD) - newfd++ - } - for _, goferFD := range args.GoferFDs { - info.goferFDs = append(info.goferFDs, fd.New(goferFD)) - } - eid := execID{cid: args.ID} l := &Loader{ k: k, @@ -491,10 +495,6 @@ func (l *Loader) Destroy() { // save/restore. l.k.Release() - // All sentry-created resources should have been released at this point; - // check for reference leaks. - refsvfs2.DoLeakCheck() - // In the success case, stdioFDs and goferFDs will only contain // released/closed FDs that ownership has been passed over to host FDs and // gofer sessions. Close them here in case of failure. @@ -752,7 +752,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn // Setup the child container file system. 
l.startGoferMonitor(cid, info.goferFDs) - mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints, kernel.VFS2Enabled) + mntr := newContainerMounter(info, l.k, l.mountHints, kernel.VFS2Enabled) if root { if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil { return nil, nil, nil, err @@ -1000,6 +1000,15 @@ func (l *Loader) waitContainer(cid string, waitStatus *uint32) error { // consider the container exited. ws := l.wait(tg) *waitStatus = ws + + // Check for leaks and write coverage report after the root container has + // exited. This guarantees that the report is written in cases where the + // sandbox is killed by a signal after the ContainerWait request is completed. + if l.root.procArgs.ContainerID == cid { + // All sentry-created resources should have been released at this point. + refsvfs2.DoLeakCheck() + coverage.Report() + } return nil } diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 8b39bc59a..93c476971 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -439,7 +439,13 @@ func TestCreateMountNamespace(t *testing.T) { } defer cleanup() - mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}, false /* vfs2Enabled */) + info := containerInfo{ + conf: conf, + spec: &tc.spec, + goferFDs: []*fd.FD{fd.New(sandEnd)}, + } + + mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */) mns, err := mntr.createMountNamespace(ctx, conf) if err != nil { t.Fatalf("failed to create mount namespace: %v", err) @@ -479,7 +485,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { defer l.Destroy() defer loaderCleanup() - mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints, true /* vfs2Enabled */) + mntr := newContainerMounter(&l.root, l.k, l.mountHints, true /* vfs2Enabled */) if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil { t.Fatalf("failed process hints: %v", err) } @@ 
-702,7 +708,12 @@ func TestRestoreEnvironment(t *testing.T) { for _, ioFD := range tc.ioFDs { ioFDs = append(ioFDs, fd.New(ioFD)) } - mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}, false /* vfs2Enabled */) + info := containerInfo{ + conf: conf, + spec: tc.spec, + goferFDs: ioFDs, + } + mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */) actualRenv, err := mntr.createRestoreEnvironment(conf) if !tc.errorExpected && err != nil { t.Fatalf("could not create restore environment for test:%s", tc.name) diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 9b3dacf46..c1828bd3d 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -16,6 +16,7 @@ package boot import ( "fmt" + "path" "sort" "strings" @@ -29,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/devices/ttydev" "gvisor.dev/gvisor/pkg/sentry/devices/tundev" "gvisor.dev/gvisor/pkg/sentry/fs/user" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/cgroupfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse" @@ -37,12 +39,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/verity" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/runsc/config" + "gvisor.dev/gvisor/runsc/specutils" ) func registerFilesystems(k *kernel.Kernel) error { @@ -50,6 +54,10 @@ func registerFilesystems(k *kernel.Kernel) error { creds := auth.NewRootCredentials(k.RootUserNamespace()) vfsObj := k.VFS() + vfsObj.MustRegisterFilesystemType(cgroupfs.Name, &cgroupfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) vfsObj.MustRegisterFilesystemType(devpts.Name, &devpts.FilesystemType{}, 
&vfs.RegisterFilesystemTypeOptions{ AllowUserList: true, // TODO(b/29356795): Users may mount this once the terminals are in a @@ -60,6 +68,10 @@ func registerFilesystems(k *kernel.Kernel) error { AllowUserMount: true, AllowUserList: true, }) + vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) vfsObj.MustRegisterFilesystemType(gofer.Name, &gofer.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserList: true, }) @@ -79,9 +91,9 @@ func registerFilesystems(k *kernel.Kernel) error { AllowUserMount: true, AllowUserList: true, }) - vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, + vfsObj.MustRegisterFilesystemType(verity.Name, &verity.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserList: true, + AllowUserMount: true, }) // Setup files in devtmpfs. @@ -351,33 +363,33 @@ func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *config. 
for i := range mounts { submount := &mounts[i] - log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options) + log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.mount.Source, submount.mount.Destination, submount.mount.Type, submount.mount.Options) var ( mnt *vfs.Mount err error ) - if hint := c.hints.findMount(submount.Mount); hint != nil && hint.isSupported() { - mnt, err = c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.Mount, hint) + if hint := c.hints.findMount(submount.mount); hint != nil && hint.isSupported() { + mnt, err = c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.mount, hint) if err != nil { - return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.Destination, err) + return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.mount.Destination, err) } } else { mnt, err = c.mountSubmountVFS2(ctx, conf, mns, creds, submount) if err != nil { - return fmt.Errorf("mount submount %q: %w", submount.Destination, err) + return fmt.Errorf("mount submount %q: %w", submount.mount.Destination, err) } } if mnt != nil && mnt.ReadOnly() { // Switch to ReadWrite while we setup submounts. if err := c.k.VFS().SetMountReadOnly(mnt, false); err != nil { - return fmt.Errorf("failed to set mount at %q readwrite: %w", submount.Destination, err) + return fmt.Errorf("failed to set mount at %q readwrite: %w", submount.mount.Destination, err) } // Restore back to ReadOnly at the end. defer func() { if err := c.k.VFS().SetMountReadOnly(mnt, true); err != nil { - panic(fmt.Sprintf("failed to restore mount at %q back to readonly: %v", submount.Destination, err)) + panic(fmt.Sprintf("failed to restore mount at %q back to readonly: %v", submount.mount.Destination, err)) } }() } @@ -390,8 +402,8 @@ func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *config. 
} type mountAndFD struct { - specs.Mount - fd int + mount *specs.Mount + fd int } func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { @@ -399,15 +411,18 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { // undocumented assumption that FDs are dispensed in the order in which // they are required by mounts. var mounts []mountAndFD - for _, m := range c.mounts { - fd := -1 + for i := range c.mounts { + m := &c.mounts[i] + specutils.MaybeConvertToBindMount(m) + // Only bind mounts use host FDs; see // containerMounter.getMountNameAndOptionsVFS2. + fd := -1 if m.Type == bind { fd = c.fds.remove() } mounts = append(mounts, mountAndFD{ - Mount: m, + mount: m, fd: fd, }) } @@ -417,7 +432,7 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { // Sort the mounts so that we don't place children before parents. sort.Slice(mounts, func(i, j int) bool { - return len(mounts[i].Destination) < len(mounts[j].Destination) + return len(mounts[i].mount.Destination) < len(mounts[j].mount.Destination) }) return mounts, nil @@ -433,16 +448,16 @@ func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.C return nil, nil } - if err := c.makeMountPoint(ctx, creds, mns, submount.Destination); err != nil { - return nil, fmt.Errorf("creating mount point %q: %w", submount.Destination, err) + if err := c.makeMountPoint(ctx, creds, mns, submount.mount.Destination); err != nil { + return nil, fmt.Errorf("creating mount point %q: %w", submount.mount.Destination, err) } if useOverlay { - log.Infof("Adding overlay on top of mount %q", submount.Destination) + log.Infof("Adding overlay on top of mount %q", submount.mount.Destination) var cleanup func() opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) if err != nil { - return nil, fmt.Errorf("mounting volume with overlay at %q: %w", submount.Destination, err) + return nil, fmt.Errorf("mounting volume with overlay at %q: %w", submount.mount.Destination, 
err) } defer cleanup() fsName = overlay.Name @@ -454,26 +469,34 @@ func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.C target := &vfs.PathOperation{ Root: root, Start: root, - Path: fspath.Parse(submount.Destination), + Path: fspath.Parse(submount.mount.Destination), } mnt, err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts) if err != nil { - return nil, fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts) + return nil, fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.mount.Destination, submount.mount.Type, err, opts) } - log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts.GetFilesystemOptions.Data) + log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.mount.Source, submount.mount.Destination, submount.mount.Type, opts.GetFilesystemOptions.Data) return mnt, nil } // getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values // used for mounts. func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mountAndFD) (string, *vfs.MountOptions, bool, error) { - fsName := m.Type + fsName := m.mount.Type useOverlay := false - var data []string - var iopts interface{} + var ( + data []string + internalData interface{} + ) + + verityData, verityOpts, verityRequested, remainingMOpts, err := parseVerityMountOptions(m.mount.Options) + if err != nil { + return "", nil, false, err + } + m.mount.Options = remainingMOpts // Find filesystem name and FS specific data field. - switch m.Type { + switch m.mount.Type { case devpts.Name, devtmpfs.Name, proc.Name, sys.Name: // Nothing to do. @@ -482,7 +505,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo case tmpfs.Name: var err error - data, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...) 
+ data, err = parseAndFilterOptions(m.mount.Options, tmpfsAllowedData...) if err != nil { return "", nil, false, err } @@ -494,28 +517,35 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo // but unlikely to be correct in this context. return "", nil, false, fmt.Errorf("9P mount requires a connection FD") } - data = p9MountData(m.fd, c.getMountAccessType(conf, m.Mount), true /* vfs2 */) - iopts = gofer.InternalFilesystemOptions{ - UniqueID: m.Destination, + data = p9MountData(m.fd, c.getMountAccessType(conf, m.mount), true /* vfs2 */) + internalData = gofer.InternalFilesystemOptions{ + UniqueID: m.mount.Destination, } // If configured, add overlay to all writable mounts. - useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly + useOverlay = conf.Overlay && !mountFlags(m.mount.Options).ReadOnly + + case cgroupfs.Name: + var err error + data, err = parseAndFilterOptions(m.mount.Options, cgroupfs.SupportedMountOptions...) + if err != nil { + return "", nil, false, err + } default: - log.Warningf("ignoring unknown filesystem type %q", m.Type) + log.Warningf("ignoring unknown filesystem type %q", m.mount.Type) return "", nil, false, nil } opts := &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ Data: strings.Join(data, ","), - InternalData: iopts, + InternalData: internalData, }, InternalMount: true, } - for _, o := range m.Options { + for _, o := range m.mount.Options { switch o { case "rw": opts.ReadOnly = false @@ -525,14 +555,82 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo opts.Flags.NoATime = true case "noexec": opts.Flags.NoExec = true + case "bind", "rbind": + // These are the same as a mount with type="bind". 
default: log.Warningf("ignoring unknown mount option %q", o) } } + if verityRequested { + verityData = verityData + "root_name=" + path.Base(m.mount.Destination) + verityOpts.LowerName = fsName + verityOpts.LowerGetFSOptions = opts.GetFilesystemOptions + fsName = verity.Name + opts = &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: verityData, + InternalData: verityOpts, + }, + InternalMount: true, + } + } + return fsName, opts, useOverlay, nil } +func parseKeyValue(s string) (string, string, bool) { + tokens := strings.SplitN(s, "=", 2) + if len(tokens) < 2 { + return "", "", false + } + return strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1]), true +} + +// parseAndFilterOptions scans the provided mount options for verity-related +// mount options. It returns the parsed set of verity mount options, as well as +// the filtered set of mount options unrelated to verity. +func parseVerityMountOptions(mopts []string) (string, verity.InternalFilesystemOptions, bool, []string, error) { + nonVerity := []string{} + found := false + var rootHash string + verityOpts := verity.InternalFilesystemOptions{ + Action: verity.PanicOnViolation, + } + for _, o := range mopts { + if !strings.HasPrefix(o, "verity.") { + nonVerity = append(nonVerity, o) + continue + } + + k, v, ok := parseKeyValue(o) + if !ok { + return "", verityOpts, found, nonVerity, fmt.Errorf("invalid verity mount option with no value: %q", o) + } + + found = true + switch k { + case "verity.roothash": + rootHash = v + case "verity.action": + switch v { + case "error": + verityOpts.Action = verity.ErrorOnViolation + case "panic": + verityOpts.Action = verity.PanicOnViolation + default: + log.Warningf("Invalid verity action %q", v) + verityOpts.Action = verity.PanicOnViolation + } + default: + return "", verityOpts, found, nonVerity, fmt.Errorf("unknown verity mount option: %q", k) + } + } + verityOpts.AllowRuntimeEnable = len(rootHash) == 0 + verityData := "root_hash=" + 
rootHash + "," + return verityData, verityOpts, found, nonVerity, nil +} + // mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so. // Technically we don't have to mount tmpfs at /tmp, as we could just rely on // the host /tmp, but this is a nice optimization, and fixes some apps that call @@ -594,7 +692,7 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config // another user. This is normally done for /tmp. Options: []string{"mode=01777"}, } - _, err := c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{Mount: tmpMount}) + _, err := c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{mount: &tmpMount}) return err case syserror.ENOTDIR: @@ -633,7 +731,7 @@ func (c *containerMounter) processHintsVFS2(conf *config.Config, creds *auth.Cre func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *config.Config, hint *mountHint, creds *auth.Credentials) (*vfs.Mount, error) { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. 
- mntFD := &mountAndFD{Mount: hint.mount} + mntFD := &mountAndFD{mount: &hint.mount} fsName, opts, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, mntFD) if err != nil { return nil, err @@ -643,11 +741,11 @@ func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *conf } if useOverlay { - log.Infof("Adding overlay on top of shared mount %q", mntFD.Destination) + log.Infof("Adding overlay on top of shared mount %q", mntFD.mount.Destination) var cleanup func() opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) if err != nil { - return nil, fmt.Errorf("mounting shared volume with overlay at %q: %w", mntFD.Destination, err) + return nil, fmt.Errorf("mounting shared volume with overlay at %q: %w", mntFD.mount.Destination, err) } defer cleanup() fsName = overlay.Name @@ -658,14 +756,14 @@ func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *conf // mountSharedSubmount binds mount to a previously mounted volume that is shared // among containers in the same pod. -func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount specs.Mount, source *mountHint) (*vfs.Mount, error) { +func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount *specs.Mount, source *mountHint) (*vfs.Mount, error) { if err := source.checkCompatible(mount); err != nil { return nil, err } // Ignore data and useOverlay because these were already applied to // the master mount. 
- _, opts, _, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount}) + _, opts, _, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{mount: mount}) if err != nil { return nil, err } @@ -718,7 +816,7 @@ func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Crede // configureRestore returns an updated context.Context including filesystem // state used by restore defined by conf. -func (c *containerMounter) configureRestore(ctx context.Context, conf *config.Config) (context.Context, error) { +func (c *containerMounter) configureRestore(ctx context.Context) (context.Context, error) { fdmap := make(map[string]int) fdmap["/"] = c.fds.remove() mounts, err := c.prepareMountsVFS2() @@ -728,7 +826,7 @@ func (c *containerMounter) configureRestore(ctx context.Context, conf *config.Co for i := range c.mounts { submount := &mounts[i] if submount.fd >= 0 { - fdmap[submount.Destination] = submount.fd + fdmap[submount.mount.Destination] = submount.fd } } return context.WithValue(ctx, gofer.CtxRestoreServerFDMap, fdmap), nil diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index 438b7ef3e..335e46499 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -40,23 +40,24 @@ const ( cgroupRoot = "/sys/fs/cgroup" ) -var controllers = map[string]config{ - "blkio": {ctrlr: &blockIO{}}, - "cpu": {ctrlr: &cpu{}}, - "cpuset": {ctrlr: &cpuSet{}}, - "hugetlb": {ctrlr: &hugeTLB{}, optional: true}, - "memory": {ctrlr: &memory{}}, - "net_cls": {ctrlr: &networkClass{}}, - "net_prio": {ctrlr: &networkPrio{}}, - "pids": {ctrlr: &pids{}}, +var controllers = map[string]controller{ + "blkio": &blockIO{}, + "cpu": &cpu{}, + "cpuset": &cpuSet{}, + "hugetlb": &hugeTLB{}, + "memory": &memory{}, + "net_cls": &networkClass{}, + "net_prio": &networkPrio{}, + "pids": &pids{}, // These controllers either don't have anything in the OCI spec or is // irrelevant for a sandbox. 
- "devices": {ctrlr: &noop{}}, - "freezer": {ctrlr: &noop{}}, - "perf_event": {ctrlr: &noop{}}, - "rdma": {ctrlr: &noop{}, optional: true}, - "systemd": {ctrlr: &noop{}}, + "cpuacct": &noop{}, + "devices": &noop{}, + "freezer": &noop{}, + "perf_event": &noop{}, + "rdma": &noop{isOptional: true}, + "systemd": &noop{}, } // IsOnlyV2 checks whether cgroups V2 is enabled and V1 is not. @@ -201,31 +202,26 @@ func countCpuset(cpuset string) (int, error) { return count, nil } -// LoadPaths loads cgroup paths for given 'pid', may be set to 'self'. -func LoadPaths(pid string) (map[string]string, error) { - f, err := os.Open(filepath.Join("/proc", pid, "cgroup")) +// loadPaths loads cgroup paths for given 'pid', may be set to 'self'. +func loadPaths(pid string) (map[string]string, error) { + procCgroup, err := os.Open(filepath.Join("/proc", pid, "cgroup")) if err != nil { return nil, err } - defer f.Close() + defer procCgroup.Close() - return loadPathsHelper(f) -} - -func loadPathsHelper(cgroup io.Reader) (map[string]string, error) { - // For nested containers, in /proc/self/cgroup we see paths from host, - // which don't exist in container, so recover the container paths here by - // double-checking with /proc/pid/mountinfo - mountinfo, err := os.Open("/proc/self/mountinfo") + // Load mountinfo for the current process, because it's where cgroups is + // being accessed from. 
+ mountinfo, err := os.Open(filepath.Join("/proc/self/mountinfo")) if err != nil { return nil, err } defer mountinfo.Close() - return loadPathsHelperWithMountinfo(cgroup, mountinfo) + return loadPathsHelper(procCgroup, mountinfo) } -func loadPathsHelperWithMountinfo(cgroup, mountinfo io.Reader) (map[string]string, error) { +func loadPathsHelper(cgroup, mountinfo io.Reader) (map[string]string, error) { paths := make(map[string]string) scanner := bufio.NewScanner(cgroup) @@ -242,34 +238,51 @@ func loadPathsHelperWithMountinfo(cgroup, mountinfo io.Reader) (map[string]strin for _, ctrlr := range strings.Split(tokens[1], ",") { // Remove prefix for cgroups with no controller, eg. systemd. ctrlr = strings.TrimPrefix(ctrlr, "name=") - paths[ctrlr] = tokens[2] + // Discard unknown controllers. + if _, ok := controllers[ctrlr]; ok { + paths[ctrlr] = tokens[2] + } } } if err := scanner.Err(); err != nil { return nil, err } - mfScanner := bufio.NewScanner(mountinfo) - for mfScanner.Scan() { - txt := mfScanner.Text() - fields := strings.Fields(txt) + // For nested containers, in /proc/[pid]/cgroup we see paths from host, + // which don't exist in container, so recover the container paths here by + // double-checking with /proc/[pid]/mountinfo + mountScanner := bufio.NewScanner(mountinfo) + for mountScanner.Scan() { + // Format: ID parent major:minor root mount-point options opt-fields - fs-type source super-options + // Example: 39 32 0:34 / /sys/fs/cgroup/devices rw,noexec shared:18 - cgroup cgroup rw,devices + fields := strings.Fields(mountScanner.Text()) if len(fields) < 9 || fields[len(fields)-3] != "cgroup" { + // Skip mounts that are not cgroup mounts. continue } - for _, opt := range strings.Split(fields[len(fields)-1], ",") { + // Cgroup controller type is in the super-options field. + superOptions := strings.Split(fields[len(fields)-1], ",") + for _, opt := range superOptions { // Remove prefix for cgroups with no controller, eg. systemd. 
opt = strings.TrimPrefix(opt, "name=") + + // Only considers cgroup controllers that are registered, and skip other + // irrelevant options, e.g. rw. if cgroupPath, ok := paths[opt]; ok { - root := fields[3] - relCgroupPath, err := filepath.Rel(root, cgroupPath) - if err != nil { - return nil, err + rootDir := fields[3] + if rootDir != "/" { + // When cgroup is in submount, remove repeated path components from + // cgroup path to avoid duplicating them. + relCgroupPath, err := filepath.Rel(rootDir, cgroupPath) + if err != nil { + return nil, err + } + paths[opt] = relCgroupPath } - paths[opt] = relCgroupPath } } } - if err := mfScanner.Err(); err != nil { + if err := mountScanner.Err(); err != nil { return nil, err } @@ -279,37 +292,50 @@ func loadPathsHelperWithMountinfo(cgroup, mountinfo io.Reader) (map[string]strin // Cgroup represents a group inside all controllers. For example: // Name='/foo/bar' maps to /sys/fs/cgroup/<controller>/foo/bar on // all controllers. +// +// If Name is relative, it uses the parent cgroup path to determine the +// location. For example: +// Name='foo/bar' and Parent[ctrl]="/user.slice", then it will map to +// /sys/fs/cgroup/<ctrl>/user.slice/foo/bar type Cgroup struct { Name string `json:"name"` Parents map[string]string `json:"parents"` Own map[string]bool `json:"own"` } -// New creates a new Cgroup instance if the spec includes a cgroup path. -// Returns nil otherwise. -func New(spec *specs.Spec) (*Cgroup, error) { +// NewFromSpec creates a new Cgroup instance if the spec includes a cgroup path. +// Returns nil otherwise. Cgroup paths are loaded based on the current process. +func NewFromSpec(spec *specs.Spec) (*Cgroup, error) { if spec.Linux == nil || spec.Linux.CgroupsPath == "" { return nil, nil } - return NewFromPath(spec.Linux.CgroupsPath) + return new("self", spec.Linux.CgroupsPath) } -// NewFromPath creates a new Cgroup instance. 
-func NewFromPath(cgroupsPath string) (*Cgroup, error) { +// NewFromPid loads cgroup for the given process. +func NewFromPid(pid int) (*Cgroup, error) { + return new(strconv.Itoa(pid), "") +} + +func new(pid, cgroupsPath string) (*Cgroup, error) { var parents map[string]string + + // If path is relative, load cgroup paths for the process to build the + // relative paths. if !filepath.IsAbs(cgroupsPath) { var err error - parents, err = LoadPaths("self") + parents, err = loadPaths(pid) if err != nil { return nil, fmt.Errorf("finding current cgroups: %w", err) } } - own := make(map[string]bool) - return &Cgroup{ + cg := &Cgroup{ Name: cgroupsPath, Parents: parents, - Own: own, - }, nil + Own: make(map[string]bool), + } + log.Debugf("New cgroup for pid: %s, %+v", pid, cg) + return cg, nil } // Install creates and configures cgroups according to 'res'. If cgroup path @@ -323,8 +349,8 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error { clean := cleanup.Make(func() { _ = c.Uninstall() }) defer clean.Clean() - for key, cfg := range controllers { - path := c.makePath(key) + for key, ctrlr := range controllers { + path := c.MakePath(key) if _, err := os.Stat(path); err == nil { // If cgroup has already been created; it has been setup by caller. Don't // make any changes to configuration, just join when sandbox/gofer starts. @@ -336,13 +362,16 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error { c.Own[key] = true if err := os.MkdirAll(path, 0755); err != nil { - if cfg.optional && errors.Is(err, unix.EROFS) { + if ctrlr.optional() && errors.Is(err, unix.EROFS) { + if err := ctrlr.skip(res); err != nil { + return err + } log.Infof("Skipping cgroup %q", key) continue } return err } - if err := cfg.ctrlr.set(res, path); err != nil { + if err := ctrlr.set(res, path); err != nil { return err } } @@ -359,7 +388,7 @@ func (c *Cgroup) Uninstall() error { // cgroup is managed by caller, don't touch it. 
continue } - path := c.makePath(key) + path := c.MakePath(key) log.Debugf("Removing cgroup controller for key=%q path=%q", key, path) // If we try to remove the cgroup too soon after killing the @@ -387,7 +416,7 @@ func (c *Cgroup) Uninstall() error { func (c *Cgroup) Join() (func(), error) { // First save the current state so it can be restored. undo := func() {} - paths, err := LoadPaths("self") + paths, err := loadPaths("self") if err != nil { return undo, err } @@ -414,14 +443,13 @@ func (c *Cgroup) Join() (func(), error) { } // Now join the cgroups. - for key, cfg := range controllers { - path := c.makePath(key) + for key, ctrlr := range controllers { + path := c.MakePath(key) log.Debugf("Joining cgroup %q", path) - // Writing the value 0 to a cgroup.procs file causes the - // writing process to be moved to the corresponding cgroup. - // - cgroups(7). + // Writing the value 0 to a cgroup.procs file causes the writing process to + // be moved to the corresponding cgroup - cgroups(7). if err := setValue(path, "cgroup.procs", "0"); err != nil { - if cfg.optional && os.IsNotExist(err) { + if ctrlr.optional() && os.IsNotExist(err) { continue } return undo, err @@ -432,7 +460,7 @@ func (c *Cgroup) Join() (func(), error) { // CPUQuota returns the CFS CPU quota. func (c *Cgroup) CPUQuota() (float64, error) { - path := c.makePath("cpu") + path := c.MakePath("cpu") quota, err := getInt(path, "cpu.cfs_quota_us") if err != nil { return -1, err @@ -449,7 +477,7 @@ func (c *Cgroup) CPUQuota() (float64, error) { // CPUUsage returns the total CPU usage of the cgroup. func (c *Cgroup) CPUUsage() (uint64, error) { - path := c.makePath("cpuacct") + path := c.MakePath("cpuacct") usage, err := getValue(path, "cpuacct.usage") if err != nil { return 0, err @@ -459,7 +487,7 @@ func (c *Cgroup) CPUUsage() (uint64, error) { // NumCPU returns the number of CPUs configured in 'cpuset/cpuset.cpus'. 
func (c *Cgroup) NumCPU() (int, error) { - path := c.makePath("cpuset") + path := c.MakePath("cpuset") cpuset, err := getValue(path, "cpuset.cpus") if err != nil { return 0, err @@ -469,7 +497,7 @@ func (c *Cgroup) NumCPU() (int, error) { // MemoryLimit returns the memory limit. func (c *Cgroup) MemoryLimit() (uint64, error) { - path := c.makePath("memory") + path := c.MakePath("memory") limStr, err := getValue(path, "memory.limit_in_bytes") if err != nil { return 0, err @@ -477,7 +505,8 @@ func (c *Cgroup) MemoryLimit() (uint64, error) { return strconv.ParseUint(strings.TrimSpace(limStr), 10, 64) } -func (c *Cgroup) makePath(controllerName string) string { +// MakePath builds a path to the given controller. +func (c *Cgroup) MakePath(controllerName string) string { path := c.Name if parent, ok := c.Parents[controllerName]; ok { path = filepath.Join(parent, c.Name) @@ -485,22 +514,48 @@ func (c *Cgroup) makePath(controllerName string) string { return filepath.Join(cgroupRoot, controllerName, path) } -type config struct { - ctrlr controller - optional bool -} - type controller interface { + // optional controllers don't fail if not found. + optional() bool + // set applies resource limits to controller. set(*specs.LinuxResources, string) error + // skip is called when controller is not found to check if it can be safely + // skipped or not based on the spec. 
+ skip(*specs.LinuxResources) error +} + +type noop struct { + isOptional bool } -type noop struct{} +func (n *noop) optional() bool { + return n.isOptional +} func (*noop) set(*specs.LinuxResources, string) error { return nil } -type memory struct{} +func (n *noop) skip(*specs.LinuxResources) error { + if !n.isOptional { + panic("cgroup controller is not optional") + } + return nil +} + +type mandatory struct{} + +func (*mandatory) optional() bool { + return false +} + +func (*mandatory) skip(*specs.LinuxResources) error { + panic("cgroup controller is not optional") +} + +type memory struct { + mandatory +} func (*memory) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.Memory == nil { @@ -533,7 +588,9 @@ func (*memory) set(spec *specs.LinuxResources, path string) error { return nil } -type cpu struct{} +type cpu struct { + mandatory +} func (*cpu) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.CPU == nil { @@ -554,7 +611,9 @@ func (*cpu) set(spec *specs.LinuxResources, path string) error { return setOptionalValueInt(path, "cpu.rt_runtime_us", spec.CPU.RealtimeRuntime) } -type cpuSet struct{} +type cpuSet struct { + mandatory +} func (*cpuSet) set(spec *specs.LinuxResources, path string) error { // cpuset.cpus and mems are required fields, but are not set on a new cgroup. 
@@ -576,7 +635,9 @@ func (*cpuSet) set(spec *specs.LinuxResources, path string) error { return setValue(path, "cpuset.mems", spec.CPU.Mems) } -type blockIO struct{} +type blockIO struct { + mandatory +} func (*blockIO) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.BlockIO == nil { @@ -628,6 +689,10 @@ func setThrottle(path, name string, devs []specs.LinuxThrottleDevice) error { type networkClass struct{} +func (*networkClass) optional() bool { + return true +} + func (*networkClass) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.Network == nil { return nil @@ -635,8 +700,19 @@ func (*networkClass) set(spec *specs.LinuxResources, path string) error { return setOptionalValueUint32(path, "net_cls.classid", spec.Network.ClassID) } +func (*networkClass) skip(spec *specs.LinuxResources) error { + if spec != nil && spec.Network != nil && spec.Network.ClassID != nil { + return fmt.Errorf("Network.ClassID set but net_cls cgroup controller not found") + } + return nil +} + type networkPrio struct{} +func (*networkPrio) optional() bool { + return true +} + func (*networkPrio) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.Network == nil { return nil @@ -650,7 +726,16 @@ func (*networkPrio) set(spec *specs.LinuxResources, path string) error { return nil } -type pids struct{} +func (*networkPrio) skip(spec *specs.LinuxResources) error { + if spec != nil && spec.Network != nil && len(spec.Network.Priorities) > 0 { + return fmt.Errorf("Network.Priorities set but net_prio cgroup controller not found") + } + return nil +} + +type pids struct { + mandatory +} func (*pids) set(spec *specs.LinuxResources, path string) error { if spec == nil || spec.Pids == nil || spec.Pids.Limit <= 0 { @@ -662,6 +747,17 @@ func (*pids) set(spec *specs.LinuxResources, path string) error { type hugeTLB struct{} +func (*hugeTLB) optional() bool { + return true +} + +func (*hugeTLB) skip(spec *specs.LinuxResources) 
error { + if spec != nil && len(spec.HugepageLimits) > 0 { + return fmt.Errorf("HugepageLimits set but hugetlb cgroup controller not found") + } + return nil +} + func (*hugeTLB) set(spec *specs.LinuxResources, path string) error { if spec == nil { return nil diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go index 48d71cfa6..02cadeef4 100644 --- a/runsc/cgroup/cgroup_test.go +++ b/runsc/cgroup/cgroup_test.go @@ -43,19 +43,19 @@ var debianMountinfo = ` ` var dindMountinfo = ` -1305 1304 0:64 / /sys/fs/cgroup rw - tmpfs tmpfs rw,mode=755 -1306 1305 0:32 /docker/136 /sys/fs/cgroup/systemd ro master:11 - cgroup cgroup rw,xattr,name=systemd -1307 1305 0:36 /docker/136 /sys/fs/cgroup/cpu,cpuacct ro master:16 - cgroup cgroup rw,cpu,cpuacct -1308 1305 0:37 /docker/136 /sys/fs/cgroup/freezer ro master:17 - cgroup cgroup rw,freezer -1309 1305 0:38 /docker/136 /sys/fs/cgroup/hugetlb ro master:18 - cgroup cgroup rw,hugetlb -1310 1305 0:39 /docker/136 /sys/fs/cgroup/cpuset ro master:19 - cgroup cgroup rw,cpuset -1311 1305 0:40 /docker/136 /sys/fs/cgroup/net_cls,net_prio ro master:20 - cgroup cgroup rw,net_cls,net_prio -1312 1305 0:41 /docker/136 /sys/fs/cgroup/pids ro master:21 - cgroup cgroup rw,pids -1313 1305 0:42 /docker/136 /sys/fs/cgroup/perf_event ro master:22 - cgroup cgroup rw,perf_event -1314 1305 0:43 /docker/136 /sys/fs/cgroup/memory ro master:23 - cgroup cgroup rw,memory -1316 1305 0:44 /docker/136 /sys/fs/cgroup/blkio ro master:24 - cgroup cgroup rw,blkio -1317 1305 0:45 /docker/136 /sys/fs/cgroup/devices ro master:25 - cgroup cgroup rw,devices -1318 1305 0:46 / /sys/fs/cgroup/rdma ro master:26 - cgroup cgroup rw,rdma +05 04 0:64 / /sys/fs/cgroup rw - tmpfs tmpfs rw,mode=755 +06 05 0:32 /docker/136 /sys/fs/cgroup/systemd ro master:11 - cgroup cgroup rw,xattr,name=systemd +07 05 0:36 /docker/136 /sys/fs/cgroup/cpu,cpuacct ro master:16 - cgroup cgroup rw,cpu,cpuacct +08 05 0:37 /docker/136 /sys/fs/cgroup/freezer ro master:17 - cgroup cgroup 
rw,freezer +09 05 0:38 /docker/136 /sys/fs/cgroup/hugetlb ro master:18 - cgroup cgroup rw,hugetlb +10 05 0:39 /docker/136 /sys/fs/cgroup/cpuset ro master:19 - cgroup cgroup rw,cpuset +11 05 0:40 /docker/136 /sys/fs/cgroup/net_cls,net_prio ro master:20 - cgroup cgroup rw,net_cls,net_prio +12 05 0:41 /docker/136 /sys/fs/cgroup/pids ro master:21 - cgroup cgroup rw,pids +13 05 0:42 /docker/136 /sys/fs/cgroup/perf_event ro master:22 - cgroup cgroup rw,perf_event +14 05 0:43 /docker/136 /sys/fs/cgroup/memory ro master:23 - cgroup cgroup rw,memory +16 05 0:44 /docker/136 /sys/fs/cgroup/blkio ro master:24 - cgroup cgroup rw,blkio +17 05 0:45 /docker/136 /sys/fs/cgroup/devices ro master:25 - cgroup cgroup rw,devices +18 05 0:46 / /sys/fs/cgroup/rdma ro master:26 - cgroup cgroup rw,rdma ` func TestUninstallEnoent(t *testing.T) { @@ -693,36 +693,42 @@ func TestLoadPaths(t *testing.T) { err string }{ { - name: "abs-path-unknown-controller", - cgroups: "0:ctr:/path", + name: "empty", mountinfo: debianMountinfo, - want: map[string]string{"ctr": "/path"}, + }, + { + name: "abs-path", + cgroups: "0:cpu:/path", + mountinfo: debianMountinfo, + want: map[string]string{"cpu": "/path"}, }, { name: "rel-path", - cgroups: "0:ctr:rel-path", + cgroups: "0:cpu:rel-path", mountinfo: debianMountinfo, - want: map[string]string{"ctr": "rel-path"}, + want: map[string]string{"cpu": "rel-path"}, }, { name: "non-controller", cgroups: "0:name=systemd:/path", mountinfo: debianMountinfo, - want: map[string]string{"systemd": "path"}, + want: map[string]string{"systemd": "/path"}, }, { - name: "empty", + name: "unknown-controller", + cgroups: "0:ctr:/path", mountinfo: debianMountinfo, + want: map[string]string{}, }, { name: "multiple", - cgroups: "0:ctr0:/path0\n" + - "1:ctr1:/path1\n" + + cgroups: "0:cpu:/path0\n" + + "1:memory:/path1\n" + "2::/empty\n", mountinfo: debianMountinfo, want: map[string]string{ - "ctr0": "/path0", - "ctr1": "/path1", + "cpu": "/path0", + "memory": "/path1", }, }, { @@ 
-747,10 +753,10 @@ func TestLoadPaths(t *testing.T) { }, { name: "nested-cgroup", - cgroups: `9:memory:/docker/136 -2:cpu,cpuacct:/docker/136 -1:name=systemd:/docker/136 -0::/system.slice/containerd.service`, + cgroups: "9:memory:/docker/136\n" + + "2:cpu,cpuacct:/docker/136\n" + + "1:name=systemd:/docker/136\n" + + "0::/system.slice/containerd.service\n", mountinfo: dindMountinfo, // we want relative path to /sys/fs/cgroup inside the nested container. // Subcroup inside the container will be created at /sys/fs/cgroup/cpu @@ -781,15 +787,15 @@ func TestLoadPaths(t *testing.T) { }, { name: "invalid-rel-path-in-proc-cgroup", - cgroups: "9:memory:./invalid", + cgroups: "9:memory:invalid", mountinfo: dindMountinfo, - err: "can't make ./invalid relative to /docker/136", + err: "can't make invalid relative to /docker/136", }, } { t.Run(tc.name, func(t *testing.T) { r := strings.NewReader(tc.cgroups) mountinfo := strings.NewReader(tc.mountinfo) - got, err := loadPathsHelperWithMountinfo(r, mountinfo) + got, err := loadPathsHelper(r, mountinfo) if len(tc.err) == 0 { if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -813,3 +819,47 @@ func TestLoadPaths(t *testing.T) { }) } } + +func TestOptional(t *testing.T) { + for _, tc := range []struct { + name string + ctrlr controller + spec *specs.LinuxResources + err string + }{ + { + name: "net-cls", + ctrlr: &networkClass{}, + spec: &specs.LinuxResources{Network: &specs.LinuxNetwork{ClassID: uint32Ptr(1)}}, + err: "Network.ClassID set but net_cls cgroup controller not found", + }, + { + name: "net-prio", + ctrlr: &networkPrio{}, + spec: &specs.LinuxResources{Network: &specs.LinuxNetwork{ + Priorities: []specs.LinuxInterfacePriority{ + {Name: "foo", Priority: 1}, + }, + }}, + err: "Network.Priorities set but net_prio cgroup controller not found", + }, + { + name: "hugetlb", + ctrlr: &hugeTLB{}, + spec: &specs.LinuxResources{HugepageLimits: []specs.LinuxHugepageLimit{ + {Pagesize: "1", Limit: 2}, + }}, + err: "HugepageLimits 
set but hugetlb cgroup controller not found", + }, + } { + t.Run(tc.name, func(t *testing.T) { + err := tc.ctrlr.skip(tc.spec) + if err == nil { + t.Fatalf("ctrlr.skip() didn't fail") + } + if !strings.Contains(err.Error(), tc.err) { + t.Errorf("ctrlr.skip() want: *%s*, got: %q", tc.err, err) + } + }) + } +} diff --git a/runsc/cli/BUILD b/runsc/cli/BUILD index f1e3cce68..360e3cea6 100644 --- a/runsc/cli/BUILD +++ b/runsc/cli/BUILD @@ -10,8 +10,10 @@ go_library( "//runsc:__pkg__", ], deps = [ + "//pkg/coverage", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/platform", "//runsc/cmd", "//runsc/config", diff --git a/runsc/cli/main.go b/runsc/cli/main.go index a3c515f4b..76184cd9c 100644 --- a/runsc/cli/main.go +++ b/runsc/cli/main.go @@ -27,8 +27,10 @@ import ( "github.com/google/subcommands" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/coverage" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/runsc/cmd" "gvisor.dev/gvisor/runsc/config" @@ -50,6 +52,7 @@ var ( logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") + coverageFD = flag.Int("coverage-fd", -1, "file descriptor to write Go coverage output.") ) // Main is the main entrypoint. @@ -86,6 +89,7 @@ func Main(version string) { subcommands.Register(new(cmd.Symbolize), "") subcommands.Register(new(cmd.Wait), "") subcommands.Register(new(cmd.Mitigate), "") + subcommands.Register(new(cmd.VerityPrepare), "") // Register internal commands with the internal group name. This causes // them to be sorted below the user-facing commands with empty group. 
@@ -204,6 +208,10 @@ func Main(version string) { } else if conf.AlsoLogToStderr { e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)} } + if *coverageFD >= 0 { + f := os.NewFile(uintptr(*coverageFD), "coverage file") + coverage.EnableReport(f) + } log.SetTarget(e) @@ -233,6 +241,9 @@ func Main(version string) { // Call the subcommand and pass in the configuration. var ws unix.WaitStatus subcmdCode := subcommands.Execute(context.Background(), conf, &ws) + // Check for leaks and write coverage report before os.Exit(). + refsvfs2.DoLeakCheck() + coverage.Report() if subcmdCode == subcommands.ExitSuccess { log.Infof("Exiting with status: %v", ws) if ws.Signaled() { diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 2c3b4058b..39c8ff603 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -23,6 +23,7 @@ go_library( "kill.go", "list.go", "mitigate.go", + "mitigate_extras.go", "path.go", "pause.go", "ps.go", @@ -35,6 +36,7 @@ go_library( "statefile.go", "symbolize.go", "syscalls.go", + "verity_prepare.go", "wait.go", ], visibility = [ diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go index 455c57692..5485db149 100644 --- a/runsc/cmd/do.go +++ b/runsc/cmd/do.go @@ -126,9 +126,8 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su Hostname: hostname, } - specutils.LogSpec(spec) - cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000)) + if conf.Network == config.NetworkNone { addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace}) @@ -154,55 +153,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su } } - out, err := json.Marshal(spec) - if err != nil { - return Errorf("Error to marshal spec: %v", err) - } - tmpDir, err := ioutil.TempDir("", "runsc-do") - if err != nil { - return Errorf("Error to create tmp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - log.Infof("Changing configuration RootDir to %q", tmpDir) - conf.RootDir = tmpDir - - cfgPath := filepath.Join(tmpDir, 
"config.json") - if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil { - return Errorf("Error write spec: %v", err) - } - - containerArgs := container.Args{ - ID: cid, - Spec: spec, - BundleDir: tmpDir, - Attached: true, - } - ct, err := container.New(conf, containerArgs) - if err != nil { - return Errorf("creating container: %v", err) - } - defer ct.Destroy() - - if err := ct.Start(conf); err != nil { - return Errorf("starting container: %v", err) - } - - // Forward signals to init in the container. Thus if we get SIGINT from - // ^C, the container gracefully exit, and we can clean up. - // - // N.B. There is a still a window before this where a signal may kill - // this process, skipping cleanup. - stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */) - defer stopForwarding() - - ws, err := ct.Wait() - if err != nil { - return Errorf("waiting for container: %v", err) - } - - *waitStatus = ws - return subcommands.ExitSuccess + return startContainerAndWait(spec, conf, cid, waitStatus) } func addNamespace(spec *specs.Spec, ns specs.LinuxNamespace) { @@ -397,3 +348,58 @@ func calculatePeerIP(ip string) (string, error) { } return fmt.Sprintf("%s.%s.%s.%d", parts[0], parts[1], parts[2], n), nil } + +func startContainerAndWait(spec *specs.Spec, conf *config.Config, cid string, waitStatus *unix.WaitStatus) subcommands.ExitStatus { + specutils.LogSpec(spec) + + out, err := json.Marshal(spec) + if err != nil { + return Errorf("Error to marshal spec: %v", err) + } + tmpDir, err := ioutil.TempDir("", "runsc-do") + if err != nil { + return Errorf("Error to create tmp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + log.Infof("Changing configuration RootDir to %q", tmpDir) + conf.RootDir = tmpDir + + cfgPath := filepath.Join(tmpDir, "config.json") + if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil { + return Errorf("Error write spec: %v", err) + } + + containerArgs := container.Args{ + ID: cid, + Spec: spec, + BundleDir: tmpDir, + 
Attached: true, + } + + ct, err := container.New(conf, containerArgs) + if err != nil { + return Errorf("creating container: %v", err) + } + defer ct.Destroy() + + if err := ct.Start(conf); err != nil { + return Errorf("starting container: %v", err) + } + + // Forward signals to init in the container. Thus if we get SIGINT from + // ^C, the container gracefully exit, and we can clean up. + // + // N.B. There is a still a window before this where a signal may kill + // this process, skipping cleanup. + stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */) + defer stopForwarding() + + ws, err := ct.Wait() + if err != nil { + return Errorf("waiting for container: %v", err) + } + + *waitStatus = ws + return subcommands.ExitSuccess +} diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 4cb0164dd..6a755ecb6 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -176,7 +176,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) mountIdx := 1 // first one is the root for _, m := range spec.Mounts { - if specutils.Is9PMount(m) { + if specutils.Is9PMount(m, conf.VFS2) { cfg := fsgofer.Config{ ROMount: isReadonlyMount(m.Options) || conf.Overlay, HostUDS: conf.FSGoferHostUDS, @@ -350,7 +350,7 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error { // creates directories as needed. 
func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error { for _, m := range mounts { - if m.Type != "bind" || !specutils.IsVFS1SupportedDevMount(m) { + if !specutils.Is9PMount(m, conf.VFS2) { continue } @@ -390,7 +390,7 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error { func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) { cleanMounts := make([]specs.Mount, 0, len(mounts)) for _, m := range mounts { - if m.Type != "bind" || !specutils.IsVFS1SupportedDevMount(m) { + if !specutils.Is9PMount(m, conf.VFS2) { cleanMounts = append(cleanMounts, m) continue } diff --git a/runsc/cmd/mitigate.go b/runsc/cmd/mitigate.go index 720141aa5..f4e65adb8 100644 --- a/runsc/cmd/mitigate.go +++ b/runsc/cmd/mitigate.go @@ -41,8 +41,8 @@ type Mitigate struct { reverse bool // Path to file to read to create CPUSet. path string - // Callback to check if a given thread is vulnerable. - vulnerable func(other mitigate.Thread) bool + // Extra data for post mitigate operations. + data string } // Name implements subcommands.command.name. @@ -55,19 +55,20 @@ func (*Mitigate) Synopsis() string { return "mitigate mitigates the underlying system against side channel attacks" } -// Usage implments Usage for cmd.Mitigate. +// Usage implements Usage for cmd.Mitigate. func (m Mitigate) Usage() string { - return `mitigate [flags] + return fmt.Sprintf(`mitigate [flags] mitigate mitigates a system to the "MDS" vulnerability by implementing a manual shutdown of SMT. The command checks /proc/cpuinfo for cpus having the MDS vulnerability, and if found, shutdown all but one CPU per hyperthread pair via /sys/devices/system/cpu/cpu{N}/online. CPUs can be restored by writing "2" to each file in /sys/devices/system/cpu/cpu{N}/online or performing a system reboot. 
-The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online.` +The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online.%s`, m.usage()) } // SetFlags sets flags for the command Mitigate. func (m *Mitigate) SetFlags(f *flag.FlagSet) { f.BoolVar(&m.dryRun, "dryrun", false, "run the command without changing system") f.BoolVar(&m.reverse, "reverse", false, "reverse mitigate by enabling all CPUs") + m.setFlags(f) } // Execute implements subcommands.Command.Execute. @@ -87,13 +88,17 @@ func (m *Mitigate) Execute(_ context.Context, f *flag.FlagSet, args ...interface m.path = allPossibleCPUs } - m.vulnerable = func(other mitigate.Thread) bool { - return other.IsVulnerable() + set, err := m.doExecute() + if err != nil { + return Errorf("Execute failed: %v", err) } - if _, err := m.doExecute(); err != nil { - log.Warningf("Execute failed: %v", err) - return subcommands.ExitFailure + if m.data == "" { + return subcommands.ExitSuccess + } + + if err = m.postMitigate(set); err != nil { + return Errorf("Post Mitigate failed: %v", err) } return subcommands.ExitSuccess @@ -104,32 +109,26 @@ func (m *Mitigate) doExecute() (mitigate.CPUSet, error) { if m.dryRun { log.Infof("Running with DryRun. 
No cpu settings will be changed.") } + data, err := ioutil.ReadFile(m.path) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %w", m.path, err) + } if m.reverse { - data, err := ioutil.ReadFile(m.path) - if err != nil { - return nil, fmt.Errorf("failed to read %s: %v", m.path, err) - } - set, err := m.doReverse(data) if err != nil { - return nil, fmt.Errorf("reverse operation failed: %v", err) + return nil, fmt.Errorf("reverse operation failed: %w", err) } return set, nil } - - data, err := ioutil.ReadFile(m.path) - if err != nil { - return nil, fmt.Errorf("failed to read %s: %v", m.path, err) - } set, err := m.doMitigate(data) if err != nil { - return nil, fmt.Errorf("mitigate operation failed: %v", err) + return nil, fmt.Errorf("mitigate operation failed: %w", err) } return set, nil } func (m *Mitigate) doMitigate(data []byte) (mitigate.CPUSet, error) { - set, err := mitigate.NewCPUSet(data, m.vulnerable) + set, err := mitigate.NewCPUSet(data) if err != nil { return nil, err } @@ -145,7 +144,7 @@ func (m *Mitigate) doMitigate(data []byte) (mitigate.CPUSet, error) { continue } if err := t.Disable(); err != nil { - return nil, fmt.Errorf("error disabling thread: %s err: %v", t, err) + return nil, fmt.Errorf("error disabling thread: %s err: %w", t, err) } } log.Infof("Shutdown successful.") @@ -170,7 +169,7 @@ func (m *Mitigate) doReverse(data []byte) (mitigate.CPUSet, error) { continue } if err := t.Enable(); err != nil { - return nil, fmt.Errorf("error enabling thread: %s err: %v", t, err) + return nil, fmt.Errorf("error enabling thread: %s err: %w", t, err) } } log.Infof("Enable successful.") diff --git a/runsc/cmd/mitigate_extras.go b/runsc/cmd/mitigate_extras.go new file mode 100644 index 000000000..2cb2833f0 --- /dev/null +++ b/runsc/cmd/mitigate_extras.go @@ -0,0 +1,33 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "gvisor.dev/gvisor/runsc/flag" + "gvisor.dev/gvisor/runsc/mitigate" +) + +// usage returns any extra bits of the usage string. +func (m *Mitigate) usage() string { + return "" +} + +// setFlags sets extra flags for the command Mitigate. +func (m *Mitigate) setFlags(f *flag.FlagSet) {} + +// postMitigate handles any postMitigate actions. +func (m *Mitigate) postMitigate(_ mitigate.CPUSet) error { + return nil +} diff --git a/runsc/cmd/mitigate_test.go b/runsc/cmd/mitigate_test.go index 54211ce32..2d3fef7c1 100644 --- a/runsc/cmd/mitigate_test.go +++ b/runsc/cmd/mitigate_test.go @@ -23,7 +23,6 @@ import ( "strings" "testing" - "gvisor.dev/gvisor/runsc/mitigate" "gvisor.dev/gvisor/runsc/mitigate/mock" ) @@ -86,9 +85,6 @@ power management::84 t.Run(tc.name, func(t *testing.T) { m := &Mitigate{ dryRun: true, - vulnerable: func(other mitigate.Thread) bool { - return other.IsVulnerable() - }, } m.doExecuteTest(t, "Mitigate", tc.mitigateData, tc.mitigateCPU, tc.mitigateError) @@ -106,9 +102,6 @@ func TestExecuteSmoke(t *testing.T) { m := &Mitigate{ dryRun: true, - vulnerable: func(other mitigate.Thread) bool { - return other.IsVulnerable() - }, } m.doExecuteTest(t, "Mitigate", string(smokeMitigate), 0, nil) diff --git a/runsc/cmd/symbolize.go b/runsc/cmd/symbolize.go index fc0c69358..0fa4bfda1 100644 --- a/runsc/cmd/symbolize.go +++ b/runsc/cmd/symbolize.go @@ -65,13 +65,15 @@ func (c 
*Symbolize) Execute(_ context.Context, f *flag.FlagSet, args ...interfac f.Usage() return subcommands.ExitUsageError } - if !coverage.KcovAvailable() { + if !coverage.Available() { return Errorf("symbolize can only be used when coverage is available.") } coverage.InitCoverageData() if c.dumpAll { - coverage.WriteAllBlocks(os.Stdout) + if err := coverage.WriteAllBlocks(os.Stdout); err != nil { + return Errorf("Failed to write out blocks: %v", err) + } return subcommands.ExitSuccess } diff --git a/runsc/cmd/verity_prepare.go b/runsc/cmd/verity_prepare.go new file mode 100644 index 000000000..66128b2a3 --- /dev/null +++ b/runsc/cmd/verity_prepare.go @@ -0,0 +1,108 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "context" + "fmt" + "math/rand" + "os" + + "github.com/google/subcommands" + specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/runsc/config" + "gvisor.dev/gvisor/runsc/flag" + "gvisor.dev/gvisor/runsc/specutils" +) + +// VerityPrepare implements subcommands.Commands for the "verity-prepare" +// command. It sets up a sandbox with a writable verity mount mapped to "--dir", +// and executes the verity measure tool specified by "--tool" in the sandbox. It +// is intended to prepare --dir to be mounted as a verity filesystem. 
+type VerityPrepare struct { + root string + tool string + dir string +} + +// Name implements subcommands.Command.Name. +func (*VerityPrepare) Name() string { + return "verity-prepare" +} + +// Synopsis implements subcommands.Command.Synopsis. +func (*VerityPrepare) Synopsis() string { + return "Generates the data structures necessary to enable verityfs on a filesystem." +} + +// Usage implements subcommands.Command.Usage. +func (*VerityPrepare) Usage() string { + return "verity-prepare --tool=<measure_tool> --dir=<path>" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (c *VerityPrepare) SetFlags(f *flag.FlagSet) { + f.StringVar(&c.root, "root", "/", `path to the root directory, defaults to "/"`) + f.StringVar(&c.tool, "tool", "", "path to the verity measure_tool") + f.StringVar(&c.dir, "dir", "", "path to the directory to be hashed") +} + +// Execute implements subcommands.Command.Execute. +func (c *VerityPrepare) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + conf := args[0].(*config.Config) + waitStatus := args[1].(*unix.WaitStatus) + + hostname, err := os.Hostname() + if err != nil { + return Errorf("Error to retrieve hostname: %v", err) + } + + // Map the entire host file system. + absRoot, err := resolvePath(c.root) + if err != nil { + return Errorf("Error resolving root: %v", err) + } + + spec := &specs.Spec{ + Root: &specs.Root{ + Path: absRoot, + }, + Process: &specs.Process{ + Cwd: absRoot, + Args: []string{c.tool, "--path", "/verityroot"}, + Env: os.Environ(), + Capabilities: specutils.AllCapabilities(), + }, + Hostname: hostname, + Mounts: []specs.Mount{ + specs.Mount{ + Source: c.dir, + Destination: "/verityroot", + Type: "bind", + Options: []string{"verity.roothash="}, + }, + }, + } + + cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000)) + + // Force no networking, it is not necessary to run the verity measure tool. 
+ conf.Network = config.NetworkNone + + conf.Verity = true + + return startContainerAndWait(spec, conf, cid, waitStatus) +} diff --git a/runsc/config/config.go b/runsc/config/config.go index 1e5858837..fa550ebf7 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -55,6 +55,9 @@ type Config struct { // PanicLog is the path to log GO's runtime messages, if not empty. PanicLog string `flag:"panic-log"` + // CoverageReport is the path to write Go coverage information, if not empty. + CoverageReport string `flag:"coverage-report"` + // DebugLogFormat is the log format for debug. DebugLogFormat string `flag:"debug-log-format"` @@ -172,6 +175,9 @@ type Config struct { // Enables seccomp inside the sandbox. OCISeccomp bool `flag:"oci-seccomp"` + // Mounts the cgroup filesystem backed by the sentry's cgroupfs. + Cgroupfs bool `flag:"cgroupfs"` + // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in // tests. It allows runsc to start the sandbox process as the current // user, and without chrooting the sandbox process. This can be diff --git a/runsc/config/flags.go b/runsc/config/flags.go index 1d996c841..c3dca2352 100644 --- a/runsc/config/flags.go +++ b/runsc/config/flags.go @@ -44,7 +44,8 @@ func RegisterFlags() { // Debugging flags. flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") - flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.") + flag.String("panic-log", "", "file path where panic reports and other Go's runtime messages are written.") + flag.String("coverage-report", "", "file path where Go coverage reports are written. 
Reports will only be generated if runsc is built with --collect_code_coverage and --instrumentation_filter Bazel flags.") flag.Bool("log-packets", false, "enable network packet logging.") flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") flag.Bool("alsologtostderr", false, "send log messages to stderr.") @@ -75,6 +76,7 @@ func RegisterFlags() { flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") flag.Bool("vfs2", false, "enables VFSv2. This uses the new VFS layer that is faster than the previous one.") flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.") + flag.Bool("cgroupfs", false, "Automatically mount cgroupfs.") // Flags that control sandbox runtime behavior: network related. flag.Var(networkTypePtr(NetworkSandbox), "network", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 3620dc8c3..5314549d6 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -51,9 +51,7 @@ go_test( ], library = ":container", shard_count = more_shards, - tags = [ - "requires-kvm", - ], + tags = ["requires-kvm"], deps = [ "//pkg/abi/linux", "//pkg/bits", diff --git a/runsc/container/container.go b/runsc/container/container.go index f9d83c118..0820edaec 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -233,7 +233,7 @@ func New(conf *config.Config, args Args) (*Container, error) { } // Create and join cgroup before processes are created to ensure they are // part of the cgroup from the start (and all their children processes). 
- cg, err := cgroup.New(args.Spec) + cg, err := cgroup.NewFromSpec(args.Spec) if err != nil { return nil, err } @@ -886,7 +886,7 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu // Add root mount and then add any other additional mounts. mountCount := 1 for _, m := range spec.Mounts { - if specutils.Is9PMount(m) { + if specutils.Is9PMount(m, conf.VFS2) { mountCount++ } } @@ -1132,7 +1132,7 @@ func (c *Container) populateStats(event *boot.EventOut) { // account for the full cgroup CPU usage. We split cgroup usage // proportionally according to the sentry-internal usage measurements, // only counting Running containers. - log.Warningf("event.ContainerUsage: %v", event.ContainerUsage) + log.Debugf("event.ContainerUsage: %v", event.ContainerUsage) var containerUsage uint64 var allContainersUsage uint64 for ID, usage := range event.ContainerUsage { @@ -1142,7 +1142,7 @@ func (c *Container) populateStats(event *boot.EventOut) { } } - cgroup, err := c.Sandbox.FindCgroup() + cgroup, err := c.Sandbox.NewCGroup() if err != nil { // No cgroup, so rely purely on the sentry's accounting. log.Warningf("events: no cgroups") @@ -1159,17 +1159,18 @@ func (c *Container) populateStats(event *boot.EventOut) { return } - // If the sentry reports no memory usage, fall back on cgroups and - // split usage equally across containers. + // If the sentry reports no CPU usage, fall back on cgroups and split usage + // equally across containers. if allContainersUsage == 0 { log.Warningf("events: no sentry CPU usage reported") allContainersUsage = cgroupsUsage containerUsage = cgroupsUsage / uint64(len(event.ContainerUsage)) } - log.Warningf("%f, %f, %f", containerUsage, cgroupsUsage, allContainersUsage) // Scaling can easily overflow a uint64 (e.g. a containerUsage and // cgroupsUsage of 16 seconds each will overflow), so use floats. 
- event.Event.Data.CPU.Usage.Total = uint64(float64(containerUsage) * (float64(cgroupsUsage) / float64(allContainersUsage))) + total := float64(containerUsage) * (float64(cgroupsUsage) / float64(allContainersUsage)) + log.Debugf("Usage, container: %d, cgroups: %d, all: %d, total: %.0f", containerUsage, cgroupsUsage, allContainersUsage, total) + event.Event.Data.CPU.Usage.Total = uint64(total) return } diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 5a0c468a4..0e79877b7 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -2449,6 +2449,27 @@ func TestCreateWithCorruptedStateFile(t *testing.T) { } } +func TestBindMountByOption(t *testing.T) { + for name, conf := range configs(t, all...) { + t.Run(name, func(t *testing.T) { + dir, err := ioutil.TempDir(testutil.TmpDir(), "bind-mount") + spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file")) + if err != nil { + t.Fatalf("ioutil.TempDir(): %v", err) + } + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: dir, + Source: dir, + Type: "none", + Options: []string{"rw", "bind"}, + }) + if err := run(spec, conf); err != nil { + t.Fatalf("error running sandbox: %v", err) + } + }) + } +} + func execute(cont *Container, name string, arg ...string) (unix.WaitStatus, error) { args := &control.ExecArgs{ Filename: name, diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 0f0a223ce..0dbe1e323 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cleanup" @@ -1510,7 +1511,7 @@ func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { Destination: "/mydir/test", Source: "/some/dir", Type: "tmpfs", - Options: []string{"rw", "rbind", "relatime"}, + 
Options: []string{"rw", "relatime"}, } podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) @@ -1917,9 +1918,9 @@ func TestMultiContainerEvent(t *testing.T) { } defer cleanup() - for _, cont := range containers { - t.Logf("Running containerd %s", cont.ID) - } + t.Logf("Running container sleep %s", containers[0].ID) + t.Logf("Running container busy %s", containers[1].ID) + t.Logf("Running container quick %s", containers[2].ID) // Wait for last container to stabilize the process count that is // checked further below. @@ -1940,50 +1941,61 @@ func TestMultiContainerEvent(t *testing.T) { } // Check events for running containers. - var prevUsage uint64 for _, cont := range containers[:2] { ret, err := cont.Event() if err != nil { - t.Errorf("Container.Events(): %v", err) + t.Errorf("Container.Event(%q): %v", cont.ID, err) } evt := ret.Event if want := "stats"; evt.Type != want { - t.Errorf("Wrong event type, want: %s, got: %s", want, evt.Type) + t.Errorf("Wrong event type, cid: %q, want: %s, got: %s", cont.ID, want, evt.Type) } if cont.ID != evt.ID { t.Errorf("Wrong container ID, want: %s, got: %s", cont.ID, evt.ID) } // One process per remaining container. if got, want := evt.Data.Pids.Current, uint64(2); got != want { - t.Errorf("Wrong number of PIDs, want: %d, got: %d", want, got) + t.Errorf("Wrong number of PIDs, cid: %q, want: %d, got: %d", cont.ID, want, got) } - // Both remaining containers should have nonzero usage, and - // 'busy' should have higher usage than 'sleep'. - usage := evt.Data.CPU.Usage.Total - if usage == 0 { - t.Errorf("Running container should report nonzero CPU usage, but got %d", usage) + // The exited container should always have a usage of zero. + if exited := ret.ContainerUsage[containers[2].ID]; exited != 0 { + t.Errorf("Exited container should report 0 CPU usage, got: %d", exited) + } + } + + // Check that CPU reported by busy container is higher than sleep. 
+ cb := func() error { + sleepEvt, err := containers[0].Event() + if err != nil { + return &backoff.PermanentError{Err: err} } - if usage <= prevUsage { - t.Errorf("Expected container %s to use more than %d ns of CPU, but used %d", cont.ID, prevUsage, usage) + sleepUsage := sleepEvt.Event.Data.CPU.Usage.Total + + busyEvt, err := containers[1].Event() + if err != nil { + return &backoff.PermanentError{Err: err} } - t.Logf("Container %s usage: %d", cont.ID, usage) - prevUsage = usage + busyUsage := busyEvt.Event.Data.CPU.Usage.Total - // The exited container should have a usage of zero. - if exited := ret.ContainerUsage[containers[2].ID]; exited != 0 { - t.Errorf("Exited container should report 0 CPU usage, but got %d", exited) + if busyUsage <= sleepUsage { + t.Logf("Busy container usage lower than sleep (busy: %d, sleep: %d), retrying...", busyUsage, sleepUsage) + return fmt.Errorf("Busy container should have higher usage than sleep, busy: %d, sleep: %d", busyUsage, sleepUsage) } + return nil + } + // Give time for busy container to run and use more CPU than sleep. + if err := testutil.Poll(cb, 10*time.Second); err != nil { + t.Fatal(err) } - // Check that stop and destroyed containers return error. + // Check that stopped and destroyed containers return error. 
if err := containers[1].Destroy(); err != nil { t.Fatalf("container.Destroy: %v", err) } for _, cont := range containers[1:] { - _, err := cont.Event() - if err == nil { - t.Errorf("Container.Events() should have failed, cid:%s, state: %v", cont.ID, cont.Status) + if _, err := cont.Event(); err == nil { + t.Errorf("Container.Event() should have failed, cid: %q, state: %v", cont.ID, cont.Status) } } } diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index e04ddda47..b81ede5ae 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -21,6 +21,7 @@ package fsgofer import ( + "errors" "fmt" "io" "math" @@ -58,9 +59,6 @@ var verityXattrs = map[string]struct{}{ // join is equivalent to path.Join() but skips path.Clean() which is expensive. func join(parent, child string) string { - if child == "." || child == ".." { - panic(fmt.Sprintf("invalid child path %q", child)) - } return parent + "/" + child } @@ -1226,3 +1224,56 @@ func (l *localFile) checkROMount() error { } return nil } + +func (l *localFile) MultiGetAttr(names []string) ([]p9.FullStat, error) { + stats := make([]p9.FullStat, 0, len(names)) + + if len(names) > 0 && names[0] == "" { + qid, valid, attr, err := l.GetAttr(p9.AttrMask{}) + if err != nil { + return nil, err + } + stats = append(stats, p9.FullStat{ + QID: qid, + Valid: valid, + Attr: attr, + }) + names = names[1:] + } + + parent := l.file.FD() + for _, name := range names { + child, err := unix.Openat(parent, name, openFlags|unix.O_PATH, 0) + if parent != l.file.FD() { + // Parent is no longer needed. + _ = unix.Close(parent) + } + if err != nil { + if errors.Is(err, unix.ENOENT) { + // No pont in continuing any further. 
+ break + } + return nil, err + } + + var stat unix.Stat_t + if err := unix.Fstat(child, &stat); err != nil { + _ = unix.Close(child) + return nil, err + } + valid, attr := l.fillAttr(&stat) + stats = append(stats, p9.FullStat{ + QID: l.attachPoint.makeQID(&stat), + Valid: valid, + Attr: attr, + }) + if (stat.Mode & unix.S_IFMT) != unix.S_IFDIR { + // Doesn't need to continue if entry is not a dir. Including symlinks + // that cannot be followed. + _ = unix.Close(child) + break + } + parent = child + } + return stats, nil +} diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index d7e141476..77723827a 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -703,16 +703,6 @@ func TestWalkNotFound(t *testing.T) { }) } -func TestWalkPanic(t *testing.T) { - runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) { - for _, name := range []string{".", ".."} { - assertPanic(t, func() { - s.file.Walk([]string{name}) - }) - } - }) -} - func TestWalkDup(t *testing.T) { runAll(t, func(t *testing.T, s state) { _, dup, err := s.file.Walk([]string{}) diff --git a/runsc/mitigate/mitigate.go b/runsc/mitigate/mitigate.go index 24f67414c..88409af8f 100644 --- a/runsc/mitigate/mitigate.go +++ b/runsc/mitigate/mitigate.go @@ -50,7 +50,7 @@ const ( type CPUSet map[threadID]*ThreadGroup // NewCPUSet creates a CPUSet from data read from /proc/cpuinfo. 
-func NewCPUSet(data []byte, vulnerable func(Thread) bool) (CPUSet, error) { +func NewCPUSet(data []byte) (CPUSet, error) { processors, err := getThreads(string(data)) if err != nil { return nil, err @@ -67,7 +67,7 @@ func NewCPUSet(data []byte, vulnerable func(Thread) bool) (CPUSet, error) { core = &ThreadGroup{} set[p.id] = core } - core.isVulnerable = core.isVulnerable || vulnerable(p) + core.isVulnerable = core.isVulnerable || p.IsVulnerable() core.threads = append(core.threads, p) } @@ -446,6 +446,7 @@ func buildRegex(key, match string) *regexp.Regexp { func parseRegex(data, key, match string) (string, error) { r := buildRegex(key, match) matches := r.FindStringSubmatch(data) + if len(matches) < 2 { return "", fmt.Errorf("failed to match key %q: %q", key, data) } diff --git a/runsc/mitigate/mitigate_test.go b/runsc/mitigate/mitigate_test.go index bd5a2433f..890c65f05 100644 --- a/runsc/mitigate/mitigate_test.go +++ b/runsc/mitigate/mitigate_test.go @@ -54,14 +54,13 @@ func TestMockCPUSet(t *testing.T) { } { t.Run(tc.testCase.Name, func(t *testing.T) { data := tc.testCase.MakeCPUString() - vulnerable := func(t Thread) bool { - return t.IsVulnerable() - } - set, err := NewCPUSet([]byte(data), vulnerable) + set, err := NewCPUSet([]byte(data)) if err != nil { t.Fatalf("Failed to create cpuSet: %v", err) } + t.Logf("data: %s", data) + for _, tg := range set { if err := checkSorted(tg.threads); err != nil { t.Fatalf("Failed to sort cpuSet: %v", err) @@ -260,11 +259,7 @@ func TestReadFile(t *testing.T) { t.Fatalf("Failed to read cpuinfo: %v", err) } - vulnerable := func(t Thread) bool { - return t.IsVulnerable() - } - - set, err := NewCPUSet(data, vulnerable) + set, err := NewCPUSet(data) if err != nil { t.Fatalf("Failed to parse CPU data %v\n%s", err, data) } diff --git a/runsc/mitigate/mock/mock.go b/runsc/mitigate/mock/mock.go index 2db718cb9..12c59e356 100644 --- a/runsc/mitigate/mock/mock.go +++ b/runsc/mitigate/mock/mock.go @@ -82,6 +82,19 @@ var Haswell2core = 
CPU{ ThreadsPerCore: 1, } +// AMD2 is an two core AMD machine. +var AMD2 = CPU{ + Name: "AMD", + VendorID: "AuthenticAMD", + Family: 23, + Model: 49, + ModelName: "AMD EPYC 7B12", + Bugs: "sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass", + PhysicalCores: 1, + Cores: 1, + ThreadsPerCore: 2, +} + // AMD8 is an eight core AMD machine. var AMD8 = CPU{ Name: "AMD", @@ -115,15 +128,15 @@ bugs : %s for k := 0; k < tc.ThreadsPerCore; k++ { processorNum := (i*tc.Cores+j)*tc.ThreadsPerCore + k ret += fmt.Sprintf(template, - processorNum, /*processor*/ - tc.VendorID, /*vendor_id*/ - tc.Family, /*cpu family*/ - tc.Model, /*model*/ - tc.ModelName, /*model name*/ - i, /*physical id*/ - j, /*core id*/ - tc.Cores*tc.PhysicalCores, /*cpu cores*/ - tc.Bugs, /*bugs*/ + processorNum, /*processor*/ + tc.VendorID, /*vendor_id*/ + tc.Family, /*cpu family*/ + tc.Model, /*model*/ + tc.ModelName, /*model name*/ + i, /*physical id*/ + j, /*core id*/ + k, /*cpu cores*/ + tc.Bugs, /*bugs*/ ) } } diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index f0a551a1e..bc4a3fa32 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/cleanup", "//pkg/control/client", "//pkg/control/server", + "//pkg/coverage", "//pkg/log", "//pkg/sentry/control", "//pkg/sentry/platform", diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 450f92645..8d31e33b2 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -34,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/control/client" "gvisor.dev/gvisor/pkg/control/server" + "gvisor.dev/gvisor/pkg/coverage" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/platform" @@ -309,20 +310,9 @@ func (s *Sandbox) Processes(cid string) ([]*control.Process, error) { return pl, nil } -// FindCgroup returns the sandbox's Cgroup, or an error if it does not have one. 
-func (s *Sandbox) FindCgroup() (*cgroup.Cgroup, error) { - paths, err := cgroup.LoadPaths(strconv.Itoa(s.Pid)) - if err != nil { - return nil, err - } - // runsc places sandboxes in the same cgroup for each controller, so we - // pick an arbitrary controller here to get the cgroup path. - const controller = "cpuacct" - controllerPath, ok := paths[controller] - if !ok { - return nil, fmt.Errorf("no %q controller found", controller) - } - return cgroup.NewFromPath(controllerPath) +// NewCGroup returns the sandbox's Cgroup, or an error if it does not have one. +func (s *Sandbox) NewCGroup() (*cgroup.Cgroup, error) { + return cgroup.NewFromPid(s.Pid) } // Execute runs the specified command in the container. It returns the PID of @@ -399,15 +389,15 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn cmd.Args = append(cmd.Args, "--log-fd="+strconv.Itoa(nextFD)) nextFD++ } - if conf.DebugLog != "" { - test := "" - if len(conf.TestOnlyTestNameEnv) != 0 { - // Fetch test name if one is provided and the test only flag was set. - if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok { - test = t - } - } + test := "" + if len(conf.TestOnlyTestNameEnv) != 0 { + // Fetch test name if one is provided and the test only flag was set. + if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok { + test = t + } + } + if conf.DebugLog != "" { debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot", test) if err != nil { return fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err) @@ -418,23 +408,29 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn nextFD++ } if conf.PanicLog != "" { - test := "" - if len(conf.TestOnlyTestNameEnv) != 0 { - // Fetch test name if one is provided and the test only flag was set. 
- if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok { - test = t - } - } - panicLogFile, err := specutils.DebugLogFile(conf.PanicLog, "panic", test) if err != nil { - return fmt.Errorf("opening debug log file in %q: %v", conf.PanicLog, err) + return fmt.Errorf("opening panic log file in %q: %v", conf.PanicLog, err) } defer panicLogFile.Close() cmd.ExtraFiles = append(cmd.ExtraFiles, panicLogFile) cmd.Args = append(cmd.Args, "--panic-log-fd="+strconv.Itoa(nextFD)) nextFD++ } + covFilename := conf.CoverageReport + if covFilename == "" { + covFilename = os.Getenv("GO_COVERAGE_FILE") + } + if covFilename != "" && coverage.Available() { + covFile, err := specutils.DebugLogFile(covFilename, "cov", test) + if err != nil { + return fmt.Errorf("opening debug log file in %q: %v", covFilename, err) + } + defer covFile.Close() + cmd.ExtraFiles = append(cmd.ExtraFiles, covFile) + cmd.Args = append(cmd.Args, "--coverage-fd="+strconv.Itoa(nextFD)) + nextFD++ + } // Add the "boot" command to the args. 
// @@ -486,7 +482,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn } if deviceFile, err := gPlatform.OpenDevice(); err != nil { - return fmt.Errorf("opening device file for platform %q: %v", gPlatform, err) + return fmt.Errorf("opening device file for platform %q: %v", conf.Platform, err) } else if deviceFile != nil { defer deviceFile.Close() cmd.ExtraFiles = append(cmd.ExtraFiles, deviceFile) @@ -1174,7 +1170,7 @@ func deviceFileForPlatform(name string) (*os.File, error) { f, err := p.OpenDevice() if err != nil { - return nil, fmt.Errorf("opening device file for platform %q: %v", p, err) + return nil, fmt.Errorf("opening device file for platform %q: %w", name, err) } return f, nil } diff --git a/runsc/specutils/fs.go b/runsc/specutils/fs.go index b62504a8c..9ecd0fde6 100644 --- a/runsc/specutils/fs.go +++ b/runsc/specutils/fs.go @@ -18,6 +18,7 @@ import ( "fmt" "math/bits" "path" + "strings" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" @@ -64,6 +65,12 @@ var optionsMap = map[string]mapping{ "sync": {set: true, val: unix.MS_SYNCHRONOUS}, } +// verityMountOptions is the set of valid verity mount option keys. +var verityMountOptions = map[string]struct{}{ + "verity.roothash": struct{}{}, + "verity.action": struct{}{}, +} + // propOptionsMap is similar to optionsMap, but it lists propagation options // that cannot be used together with other flags. var propOptionsMap = map[string]mapping{ @@ -117,6 +124,14 @@ func validateMount(mnt *specs.Mount) error { return nil } +func moptKey(opt string) string { + if len(opt) == 0 { + return opt + } + // Guaranteed to have at least one token, since opt is not empty. + return strings.SplitN(opt, "=", 2)[0] +} + // ValidateMountOptions validates that mount options are correct. 
func ValidateMountOptions(opts []string) error { for _, o := range opts { @@ -125,7 +140,8 @@ func ValidateMountOptions(opts []string) error { } _, ok1 := optionsMap[o] _, ok2 := propOptionsMap[o] - if !ok1 && !ok2 { + _, ok3 := verityMountOptions[moptKey(o)] + if !ok1 && !ok2 && !ok3 { return fmt.Errorf("unknown mount option %q", o) } if err := validatePropagation(o); err != nil { diff --git a/runsc/specutils/seccomp/BUILD b/runsc/specutils/seccomp/BUILD index e9e647d82..c5f5b863e 100644 --- a/runsc/specutils/seccomp/BUILD +++ b/runsc/specutils/seccomp/BUILD @@ -28,8 +28,10 @@ go_test( srcs = ["seccomp_test.go"], library = ":seccomp", deps = [ - "//pkg/binary", + "//pkg/abi/linux", "//pkg/bpf", + "//pkg/hostarch", + "//pkg/marshal", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/runsc/specutils/seccomp/seccomp_test.go b/runsc/specutils/seccomp/seccomp_test.go index 11a6c8daa..20796bf14 100644 --- a/runsc/specutils/seccomp/seccomp_test.go +++ b/runsc/specutils/seccomp/seccomp_test.go @@ -20,20 +20,15 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" ) -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - -// asInput converts a seccompData to a bpf.Input. -func asInput(d seccompData) bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +// asInput converts a linux.SeccompData to a bpf.Input. +func asInput(d *linux.SeccompData) bpf.Input { + return bpf.InputBytes{marshal.Marshal(d), hostarch.ByteOrder} } // testInput creates an Input struct with given seccomp input values. 
@@ -49,13 +44,13 @@ func testInput(arch uint32, syscallName string, args *[6]uint64) bpf.Input { args = &argArray } - data := seccompData{ - nr: syscallNo, - arch: arch, - args: *args, + data := linux.SeccompData{ + Nr: int32(syscallNo), + Arch: arch, + Args: *args, } - return asInput(data) + return asInput(&data) } // testCase holds a seccomp test case. @@ -100,7 +95,7 @@ var ( }, // Syscall matches but the arch is AUDIT_ARCH_X86 so the return // value is the bad arch action. - input: asInput(seccompData{nr: 183, arch: 0x40000003}), // + input: asInput(&linux.SeccompData{Nr: 183, Arch: 0x40000003}), // expected: uint32(killThreadAction), }, { diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index 45856fd58..11b476690 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -332,14 +332,38 @@ func capsFromNames(names []string, skipSet map[linux.Capability]struct{}) (auth. return auth.CapabilitySetOfMany(caps), nil } -// Is9PMount returns true if the given mount can be mounted as an external gofer. -func Is9PMount(m specs.Mount) bool { - return m.Type == "bind" && m.Source != "" && IsVFS1SupportedDevMount(m) +// Is9PMount returns true if the given mount can be mounted as an external +// gofer. +func Is9PMount(m specs.Mount, vfs2Enabled bool) bool { + MaybeConvertToBindMount(&m) + return m.Type == "bind" && m.Source != "" && IsSupportedDevMount(m, vfs2Enabled) } -// IsVFS1SupportedDevMount returns true if m.Destination does not specify a +// MaybeConvertToBindMount converts mount type to "bind" in case any of the +// mount options are either "bind" or "rbind" as required by the OCI spec. +// +// "For bind mounts (when options include either bind or rbind), the type is a +// dummy, often "none" (not listed in /proc/filesystems)." 
+func MaybeConvertToBindMount(m *specs.Mount) { + if m.Type == "bind" { + return + } + for _, opt := range m.Options { + if opt == "bind" || opt == "rbind" { + m.Type = "bind" + return + } + } +} + +// IsSupportedDevMount returns true if m.Destination does not specify a // path that is hardcoded by VFS1's implementation of /dev. -func IsVFS1SupportedDevMount(m specs.Mount) bool { +func IsSupportedDevMount(m specs.Mount, vfs2Enabled bool) bool { + // VFS2 has no hardcoded files under /dev, so everything is allowed. + if vfs2Enabled { + return true + } + // See pkg/sentry/fs/dev/dev.go. var existingDevices = []string{ "/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr", diff --git a/shim/BUILD b/shim/BUILD index 434269d31..695f61eb9 100644 --- a/shim/BUILD +++ b/shim/BUILD @@ -6,6 +6,7 @@ go_binary( name = "containerd-shim-runsc-v1", srcs = ["main.go"], static = True, + tags = ["staging"], visibility = [ "//visibility:public", ], diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD index 697ab5837..a5a3cf2c1 100644 --- a/test/benchmarks/base/BUILD +++ b/test/benchmarks/base/BUILD @@ -17,7 +17,6 @@ go_library( benchmark_test( name = "startup_test", - size = "enormous", srcs = ["startup_test.go"], visibility = ["//:sandbox"], deps = [ @@ -29,7 +28,6 @@ benchmark_test( benchmark_test( name = "size_test", - size = "enormous", srcs = ["size_test.go"], visibility = ["//:sandbox"], deps = [ @@ -42,7 +40,6 @@ benchmark_test( benchmark_test( name = "sysbench_test", - size = "enormous", srcs = ["sysbench_test.go"], visibility = ["//:sandbox"], deps = [ diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD index 0b1743603..fee2695ff 100644 --- a/test/benchmarks/database/BUILD +++ b/test/benchmarks/database/BUILD @@ -11,7 +11,6 @@ go_library( benchmark_test( name = "redis_test", - size = "enormous", srcs = ["redis_test.go"], library = ":database", visibility = ["//:sandbox"], diff --git a/test/benchmarks/fs/BUILD 
b/test/benchmarks/fs/BUILD index dc82e63b2..c2b981a07 100644 --- a/test/benchmarks/fs/BUILD +++ b/test/benchmarks/fs/BUILD @@ -4,7 +4,6 @@ package(licenses = ["notice"]) benchmark_test( name = "bazel_test", - size = "enormous", srcs = ["bazel_test.go"], visibility = ["//:sandbox"], deps = [ @@ -18,7 +17,6 @@ benchmark_test( benchmark_test( name = "fio_test", - size = "enormous", srcs = ["fio_test.go"], visibility = ["//:sandbox"], deps = [ diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD index 380783f0b..ad2ef3a55 100644 --- a/test/benchmarks/media/BUILD +++ b/test/benchmarks/media/BUILD @@ -11,7 +11,6 @@ go_library( benchmark_test( name = "ffmpeg_test", - size = "enormous", srcs = ["ffmpeg_test.go"], library = ":media", visibility = ["//:sandbox"], diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD index 3425b8dad..56a4d4f39 100644 --- a/test/benchmarks/ml/BUILD +++ b/test/benchmarks/ml/BUILD @@ -11,7 +11,6 @@ go_library( benchmark_test( name = "tensorflow_test", - size = "enormous", srcs = ["tensorflow_test.go"], library = ":ml", visibility = ["//:sandbox"], diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD index 2741570f5..e047020bf 100644 --- a/test/benchmarks/network/BUILD +++ b/test/benchmarks/network/BUILD @@ -18,7 +18,6 @@ go_library( benchmark_test( name = "iperf_test", - size = "enormous", srcs = [ "iperf_test.go", ], @@ -34,7 +33,6 @@ benchmark_test( benchmark_test( name = "node_test", - size = "enormous", srcs = [ "node_test.go", ], @@ -49,7 +47,6 @@ benchmark_test( benchmark_test( name = "ruby_test", - size = "enormous", srcs = [ "ruby_test.go", ], @@ -64,7 +61,6 @@ benchmark_test( benchmark_test( name = "nginx_test", - size = "enormous", srcs = [ "nginx_test.go", ], @@ -79,7 +75,6 @@ benchmark_test( benchmark_test( name = "httpd_test", - size = "enormous", srcs = [ "httpd_test.go", ], diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 29a84f184..1e9792b4f 100644 --- a/test/e2e/BUILD +++ 
b/test/e2e/BUILD @@ -8,13 +8,12 @@ go_test( srcs = [ "exec_test.go", "integration_test.go", - "regression_test.go", ], library = ":integration", tags = [ # Requires docker and runsc to be configured before the test runs. - "manual", "local", + "manual", ], visibility = ["//:sandbox"], deps = [ diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 49cd74887..1accc3b3b 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -168,13 +168,6 @@ func TestCheckpointRestore(t *testing.T) { t.Skip("Pause/resume is not supported.") } - // TODO(gvisor.dev/issue/3373): Remove after implementing. - if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 { - t.Skip("CheckpointRestore not implemented in VFS2.") - } else if err != nil { - t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err) - } - ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) @@ -399,15 +392,15 @@ func TestTmpFile(t *testing.T) { // TestTmpMount checks that mounts inside '/tmp' are not overridden. func TestTmpMount(t *testing.T) { - ctx := context.Background() dir, err := ioutil.TempDir(testutil.TmpDir(), "tmp-mount") if err != nil { t.Fatalf("TempDir(): %v", err) } - want := "123" + const want = "123" if err := ioutil.WriteFile(filepath.Join(dir, "file.txt"), []byte("123"), 0666); err != nil { t.Fatalf("WriteFile(): %v", err) } + ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) @@ -430,6 +423,48 @@ func TestTmpMount(t *testing.T) { } } +// Test that it is allowed to mount a file on top of /dev files, e.g. +// /dev/random. 
+func TestMountOverDev(t *testing.T) { + if usingVFS2, err := dockerutil.UsingVFS2(); !usingVFS2 { + t.Skip("VFS1 doesn't allow /dev/random to be mounted.") + } else if err != nil { + t.Fatalf("Failed to read config for runtime %s: %v", dockerutil.Runtime(), err) + } + + random, err := ioutil.TempFile(testutil.TmpDir(), "random") + if err != nil { + t.Fatal("ioutil.TempFile() failed:", err) + } + const want = "123" + if _, err := random.WriteString(want); err != nil { + t.Fatalf("WriteString() to %q: %v", random.Name(), err) + } + + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + opts := dockerutil.RunOpts{ + Image: "basic/alpine", + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: random.Name(), + Target: "/dev/random", + }, + }, + } + cmd := "dd count=1 bs=5 if=/dev/random 2> /dev/null" + got, err := d.Run(ctx, opts, "sh", "-c", cmd) + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + if want != got { + t.Errorf("invalid file content, want: %q, got: %q", want, got) + } +} + // TestSyntheticDirs checks that submounts can be created inside a readonly // mount even if the target path does not exist. func TestSyntheticDirs(t *testing.T) { @@ -550,6 +585,30 @@ func runIntegrationTest(t *testing.T, capAdd []string, args ...string) { } } +// Test that UDS can be created using overlay when parent directory is in lower +// layer only (b/134090485). +// +// Prerequisite: the directory where the socket file is created must not have +// been open for write before bind(2) is called. +func TestBindOverlay(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Run the container. + got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/ubuntu", + }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! 
&& sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p") + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + + // Check the output contains what we want. + if want := "foobar-asdf"; !strings.Contains(got, want) { + t.Fatalf("docker run output is missing %q: %s", want, got) + } +} + func TestMain(m *testing.M) { dockerutil.EnsureSupportedDockerVersion() flag.Parse() diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go deleted file mode 100644 index 84564cdaa..000000000 --- a/test/e2e/regression_test.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package integration - -import ( - "context" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/test/dockerutil" -) - -// Test that UDS can be created using overlay when parent directory is in lower -// layer only (b/134090485). -// -// Prerequisite: the directory where the socket file is created must not have -// been open for write before bind(2) is called. -func TestBindOverlay(t *testing.T) { - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - // Run the container. - got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/ubuntu", - }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! 
&& sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p") - if err != nil { - t.Fatalf("docker run failed: %v", err) - } - - // Check the output contains what we want. - if want := "foobar-asdf"; !strings.Contains(got, want) { - t.Fatalf("docker run output is missing %q: %s", want, got) - } -} diff --git a/test/fsstress/BUILD b/test/fsstress/BUILD index d262c8554..e74e7fff2 100644 --- a/test/fsstress/BUILD +++ b/test/fsstress/BUILD @@ -14,9 +14,7 @@ go_test( "manual", "local", ], - deps = [ - "//pkg/test/dockerutil", - ], + deps = ["//pkg/test/dockerutil"], ) go_library( diff --git a/test/fsstress/fsstress_test.go b/test/fsstress/fsstress_test.go index 300c21ceb..d53c8f90d 100644 --- a/test/fsstress/fsstress_test.go +++ b/test/fsstress/fsstress_test.go @@ -17,7 +17,9 @@ package fsstress import ( "context" + "flag" "math/rand" + "os" "strconv" "strings" "testing" @@ -30,33 +32,44 @@ func init() { rand.Seed(int64(time.Now().Nanosecond())) } -func fsstress(t *testing.T, dir string) { +func TestMain(m *testing.M) { + dockerutil.EnsureSupportedDockerVersion() + flag.Parse() + os.Exit(m.Run()) +} + +type config struct { + operations string + processes string + target string +} + +func fsstress(t *testing.T, conf config) { ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) - const ( - operations = "10000" - processes = "100" - image = "basic/fsstress" - ) + const image = "basic/fsstress" seed := strconv.FormatUint(uint64(rand.Uint32()), 10) - args := []string{"-d", dir, "-n", operations, "-p", processes, "-s", seed, "-X"} - t.Logf("Repro: docker run --rm --runtime=runsc %s %s", image, strings.Join(args, "")) + args := []string{"-d", conf.target, "-n", conf.operations, "-p", conf.processes, "-s", seed, "-X"} + t.Logf("Repro: docker run --rm --runtime=%s gvisor.dev/images/%s %s", dockerutil.Runtime(), image, strings.Join(args, " ")) out, err := d.Run(ctx, dockerutil.RunOpts{Image: image}, args...) 
if err != nil { t.Fatalf("docker run failed: %v\noutput: %s", err, out) } - lines := strings.SplitN(out, "\n", 2) - if len(lines) > 1 || !strings.HasPrefix(out, "seed =") { + // This is to catch cases where fsstress spews out error messages during clean + // up but doesn't return error. + if len(out) > 0 { t.Fatalf("unexpected output: %s", out) } } -func TestFsstressGofer(t *testing.T) { - fsstress(t, "/test") -} - func TestFsstressTmpfs(t *testing.T) { - fsstress(t, "/tmp") + // This takes between 10s to run on my machine. Adjust as needed. + cfg := config{ + operations: "5000", + processes: "20", + target: "/tmp", + } + fsstress(t, cfg) } diff --git a/test/image/image_test.go b/test/image/image_test.go index 968e62f63..952264173 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -183,7 +183,10 @@ func TestMysql(t *testing.T) { // Start the container. if err := server.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/mysql", - Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"}, + Env: []string{ + "MYSQL_ROOT_PASSWORD=foobar123", + "MYSQL_ROOT_HOST=%", // Allow anyone to connect to the server. 
+ }, }); err != nil { t.Fatalf("docker run failed: %v", err) } diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index d6c69a319..04d112134 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -456,3 +456,11 @@ func TestNATPreRECVORIGDSTADDR(t *testing.T) { func TestNATOutRECVORIGDSTADDR(t *testing.T) { singleTest(t, &NATOutRECVORIGDSTADDR{}) } + +func TestNATPostSNATUDP(t *testing.T) { + singleTest(t, &NATPostSNATUDP{}) +} + +func TestNATPostSNATTCP(t *testing.T) { + singleTest(t, &NATPostSNATTCP{}) +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index bba17b894..4590e169d 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -69,29 +69,41 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error { return nil } -// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for -// the first read on that port. +// listenUDP listens on a UDP port and returns nil if the first read from that +// port is successful. func listenUDP(ctx context.Context, port int, ipv6 bool) error { + _, err := listenUDPFrom(ctx, port, ipv6) + return err +} + +// listenUDPFrom listens on a UDP port and returns the sender's UDP address if +// the first read from that port is successful. 
+func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) { localAddr := net.UDPAddr{ Port: port, } conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr) if err != nil { - return err + return nil, err } defer conn.Close() - ch := make(chan error) + type result struct { + remoteAddr *net.UDPAddr + err error + } + + ch := make(chan result) go func() { - _, err = conn.Read([]byte{0}) - ch <- err + _, remoteAddr, err := conn.ReadFromUDP([]byte{0}) + ch <- result{remoteAddr, err} }() select { - case err := <-ch: - return err + case res := <-ch: + return res.remoteAddr, res.err case <-ctx.Done(): - return ctx.Err() + return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err()) } } @@ -125,8 +137,16 @@ func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error { } } -// listenTCP listens for connections on a TCP port. +// listenTCP listens for connections on a TCP port, and returns nil if a +// connection is established. func listenTCP(ctx context.Context, port int, ipv6 bool) error { + _, err := listenTCPFrom(ctx, port, ipv6) + return err +} + +// listenTCP listens for connections on a TCP port, and returns the remote +// TCP address if a connection is established. +func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) { localAddr := net.TCPAddr{ Port: port, } @@ -134,23 +154,32 @@ func listenTCP(ctx context.Context, port int, ipv6 bool) error { // Starts listening on port. lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr) if err != nil { - return err + return nil, err } defer lConn.Close() + type result struct { + remoteAddr net.Addr + err error + } + // Accept connections on port. 
- ch := make(chan error) + ch := make(chan result) go func() { conn, err := lConn.AcceptTCP() - ch <- err + var remoteAddr net.Addr + if err == nil { + remoteAddr = conn.RemoteAddr() + } + ch <- result{remoteAddr, err} conn.Close() }() select { - case err := <-ch: - return err + case res := <-ch: + return res.remoteAddr, res.err case <-ctx.Done(): - return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err()) + return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err()) } } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 0776639a7..0f25b6a18 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "net" + "strconv" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" @@ -48,6 +49,8 @@ func init() { RegisterTestCase(&NATOutOriginalDst{}) RegisterTestCase(&NATPreRECVORIGDSTADDR{}) RegisterTestCase(&NATOutRECVORIGDSTADDR{}) + RegisterTestCase(&NATPostSNATUDP{}) + RegisterTestCase(&NATPostSNATTCP{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. @@ -486,7 +489,12 @@ func (*NATLoopbackSkipsPrerouting) Name() string { // ContainerAction implements TestCase.ContainerAction. func (*NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect anything sent to localhost to an unused port. 
- dest := []byte{127, 0, 0, 1} + var dest net.IP + if ipv6 { + dest = net.IPv6loopback + } else { + dest = net.IPv4(127, 0, 0, 1) + } if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } @@ -915,3 +923,115 @@ func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) er } return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs) } + +const ( + snatAddrV4 = "194.236.50.155" + snatAddrV6 = "2a0a::1" + snatPort = 43 +) + +// NATPostSNATUDP tests that the source port/IP in the packets are modified as expected. +type NATPostSNATUDP struct{ localCase } + +var _ TestCase = (*NATPostSNATUDP)(nil) + +// Name implements TestCase.Name. +func (*NATPostSNATUDP) Name() string { + return "NATPostSNATUDP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (*NATPostSNATUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + var source string + if ipv6 { + source = fmt.Sprintf("[%s]:%d", snatAddrV6, snatPort) + } else { + source = fmt.Sprintf("%s:%d", snatAddrV4, snatPort) + } + + if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "udp", "-j", "SNAT", "--to-source", source); err != nil { + return err + } + return sendUDPLoop(ctx, ip, acceptPort, ipv6) +} + +// LocalAction implements TestCase.LocalAction. 
+func (*NATPostSNATUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + remote, err := listenUDPFrom(ctx, acceptPort, ipv6) + if err != nil { + return err + } + var snatAddr string + if ipv6 { + snatAddr = snatAddrV6 + } else { + snatAddr = snatAddrV4 + } + if got, want := remote.IP, net.ParseIP(snatAddr); !got.Equal(want) { + return fmt.Errorf("got remote address = %s, want = %s", got, want) + } + if got, want := remote.Port, snatPort; got != want { + return fmt.Errorf("got remote port = %d, want = %d", got, want) + } + return nil +} + +// NATPostSNATTCP tests that the source port/IP in the packets are modified as +// expected. +type NATPostSNATTCP struct{ localCase } + +var _ TestCase = (*NATPostSNATTCP)(nil) + +// Name implements TestCase.Name. +func (*NATPostSNATTCP) Name() string { + return "NATPostSNATTCP" +} + +// ContainerAction implements TestCase.ContainerAction. +func (*NATPostSNATTCP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + addrs, err := getInterfaceAddrs(ipv6) + if err != nil { + return err + } + var source string + for _, addr := range addrs { + if addr.To4() != nil { + if !ipv6 { + source = fmt.Sprintf("%s:%d", addr, snatPort) + } + } else if ipv6 && addr.IsGlobalUnicast() { + source = fmt.Sprintf("[%s]:%d", addr, snatPort) + } + } + if source == "" { + return fmt.Errorf("can't find any interface address to use") + } + + if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "tcp", "-j", "SNAT", "--to-source", source); err != nil { + return err + } + return connectTCP(ctx, ip, acceptPort, ipv6) +} + +// LocalAction implements TestCase.LocalAction. 
+func (*NATPostSNATTCP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + remote, err := listenTCPFrom(ctx, acceptPort, ipv6) + if err != nil { + return err + } + HostStr, portStr, err := net.SplitHostPort(remote.String()) + if err != nil { + return err + } + if got, want := HostStr, ip.String(); got != want { + return fmt.Errorf("got remote address = %s, want = %s", got, want) + } + port, err := strconv.ParseInt(portStr, 10, 0) + if err != nil { + return err + } + if got, want := int(port), snatPort; got != want { + return fmt.Errorf("got remote port = %d, want = %d", got, want) + } + return nil +} diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD index 5d95516ee..de66cbe6d 100644 --- a/test/packetdrill/BUILD +++ b/test/packetdrill/BUILD @@ -41,6 +41,7 @@ packetdrill_test( test_suite( name = "all_tests", tags = [ + "local", "manual", "packetdrill", ], diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 34e83ec49..afe73a69a 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -246,6 +246,15 @@ ALL_TESTS = [ expect_netstack_failure = True, ), PacketimpactTestInfo( + name = "tcp_listen_backlog", + ), + PacketimpactTestInfo( + name = "tcp_syncookie", + ), + PacketimpactTestInfo( + name = "tcp_connect_icmp_error", + ), + PacketimpactTestInfo( name = "icmpv6_param_problem", ), PacketimpactTestInfo( diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go index b271bd47e..4fb2f5c4b 100644 --- a/test/packetimpact/runner/dut.go +++ b/test/packetimpact/runner/dut.go @@ -369,30 +369,32 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co "--dut_infos_json", string(dutInfosBytes), ) testbenchLogs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) 
- if (err != nil) != expectFailure { - var dutLogs string - for i, dut := range duts { - logs, err := dut.Logs(ctx) - if err != nil { - logs = fmt.Sprintf("failed to fetch DUT logs: %s", err) - } - dutLogs = fmt.Sprintf(`%s====== Begin of DUT-%d Logs ====== + var dutLogs string + for i, dut := range duts { + logs, err := dut.Logs(ctx) + if err != nil { + logs = fmt.Sprintf("failed to fetch DUT logs: %s", err) + } + dutLogs = fmt.Sprintf(`%s====== Begin of DUT-%d Logs ====== %s ====== End of DUT-%d Logs ====== `, dutLogs, i, logs, i) - } - - t.Errorf(`test error: %v, expect failure: %t - + } + testLogs := fmt.Sprintf(` %s====== Begin of Testbench Logs ====== %s -====== End of Testbench Logs ======`, - err, expectFailure, dutLogs, testbenchLogs) +====== End of Testbench Logs ======`, dutLogs, testbenchLogs) + if (err != nil) != expectFailure { + t.Errorf(`test error: %v, expect failure: %t +%s`, err, expectFailure, testLogs) + } else if expectFailure { + t.Logf(`test failed as expected: %v +%s`, err, testLogs) } } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 92103c1e9..c4fe293e0 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -385,6 +385,36 @@ packetimpact_testbench( ], ) +packetimpact_testbench( + name = "tcp_listen_backlog", + srcs = ["tcp_listen_backlog_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_testbench( + name = "tcp_syncookie", + srcs = ["tcp_syncookie_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_testbench( + name = "tcp_connect_icmp_error", + srcs = ["tcp_connect_icmp_error_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + validate_all_tests() [packetimpact_go_test( @@ -396,6 +426,7 @@ 
validate_all_tests() test_suite( name = "all_tests", tags = [ + "local", "manual", "packetimpact", ], diff --git a/test/packetimpact/tests/tcp_connect_icmp_error_test.go b/test/packetimpact/tests/tcp_connect_icmp_error_test.go new file mode 100644 index 000000000..79bfe9eb7 --- /dev/null +++ b/test/packetimpact/tests/tcp_connect_icmp_error_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_connect_icmp_error_test + +import ( + "context" + "flag" + "sync" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +func sendICMPError(t *testing.T, conn *testbench.TCPIPv4, tcp *testbench.TCP) { + t.Helper() + + layers := conn.CreateFrame(t, nil) + layers = layers[:len(layers)-1] + ip, ok := tcp.Prev().(*testbench.IPv4) + if !ok { + t.Fatalf("expected %s to be IPv4", tcp.Prev()) + } + icmpErr := &testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), + Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable)} + + layers = append(layers, icmpErr, ip, tcp) + conn.SendFrameStateless(t, layers) +} + +// TestTCPConnectICMPError tests for the handshake to fail and the socket state +// cleaned up on receiving an ICMP error. 
+func TestTCPConnectICMPError(t *testing.T) { + dut := testbench.NewDUT(t) + + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, dut.Net.RemoteIPv4) + port := uint16(9001) + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port}) + defer conn.Close(t) + sa := unix.SockaddrInet4{Port: int(port)} + copy(sa.Addr[:], dut.Net.LocalIPv4) + // Bring the dut to SYN-SENT state with a non-blocking connect. + dut.Connect(t, clientFD, &sa) + tcp, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second) + if err != nil { + t.Fatalf("expected SYN, %s", err) + } + + done := make(chan bool) + defer close(done) + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + var block sync.WaitGroup + block.Add(1) + go func() { + defer wg.Done() + _, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + block.Done() + for { + select { + case <-done: + return + default: + if errno := dut.GetSockOptInt(t, clientFD, unix.SOL_SOCKET, unix.SO_ERROR); errno != 0 { + return + } + } + } + }() + block.Wait() + + sendICMPError(t, &conn, tcp) + + dut.PollOne(t, clientFD, unix.POLLHUP, time.Second) + + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + // The DUT should reply with RST to our ACK as the state should have + // transitioned to CLOSED because of handshake error. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected RST, %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_listen_backlog_test.go b/test/packetimpact/tests/tcp_listen_backlog_test.go new file mode 100644 index 000000000..26c812d0a --- /dev/null +++ b/test/packetimpact/tests/tcp_listen_backlog_test.go @@ -0,0 +1,86 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_listen_backlog_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +// TestTCPListenBacklog tests for a listening endpoint behavior: +// (1) reply to more SYNs than what is configured as listen backlog +// (2) ignore ACKs (that complete a handshake) when the accept queue is full +// (3) ignore incoming SYNs when the accept queue is full +func TestTCPListenBacklog(t *testing.T) { + dut := testbench.NewDUT(t) + + // Listening endpoint accepts one more connection than the listen backlog. + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 0 /*backlog*/) + + var establishedConn testbench.TCPIPv4 + var incompleteConn testbench.TCPIPv4 + + // Test if the DUT listener replies to more SYNs than listen backlog+1 + for i, conn := range []*testbench.TCPIPv4{&establishedConn, &incompleteConn} { + *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + // Expect dut connection to have transitioned to SYN-RCVD state. 
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK for %d connection, %s", i, err) + } + } + defer establishedConn.Close(t) + defer incompleteConn.Close(t) + + // Send the ACK to complete handshake. + establishedConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + dut.PollOne(t, listenFd, unix.POLLIN, time.Second) + + // Send the ACK to complete handshake, expect this to be ignored by the + // listener. + incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + + // Drain the accept queue to enable poll for subsequent connections on the + // listener. + dut.Accept(t, listenFd) + + // The ACK for the incomplete connection should be ignored by the + // listening endpoint and the poll on listener should now time out. + if pfds := dut.Poll(t, []unix.PollFd{{Fd: listenFd, Events: unix.POLLIN}}, time.Second); len(pfds) != 0 { + t.Fatalf("got dut.Poll(...) = %#v", pfds) + } + + // Re-send the ACK to complete handshake and re-fill the accept-queue. + incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + dut.PollOne(t, listenFd, unix.POLLIN, time.Second) + + // Now initiate a new connection when the accept queue is full. + connectingConn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer connectingConn.Close(t) + // Expect dut connection to drop the SYN and let the client stay in SYN_SENT state. 
+ connectingConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + if got, err := connectingConn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err == nil { + t.Fatalf("expected no SYN-ACK, but got %s", got) + } +} diff --git a/test/packetimpact/tests/tcp_syncookie_test.go b/test/packetimpact/tests/tcp_syncookie_test.go new file mode 100644 index 000000000..1c21c62ff --- /dev/null +++ b/test/packetimpact/tests/tcp_syncookie_test.go @@ -0,0 +1,70 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_syncookie_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +// TestSynCookie test if the DUT listener is replying back using syn cookies. +// The test does not complete the handshake by not sending the ACK to SYNACK. +// When syncookies are not used, this forces the listener to retransmit SYNACK. +// And when syncookies are being used, there is no such retransmit. +func TestTCPSynCookie(t *testing.T) { + dut := testbench.NewDUT(t) + + // Listening endpoint accepts one more connection than the listen backlog. 
+ _, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + + var withoutSynCookieConn testbench.TCPIPv4 + var withSynCookieConn testbench.TCPIPv4 + + // Test if the DUT listener replies to more SYNs than listen backlog+1 + for _, conn := range []*testbench.TCPIPv4{&withoutSynCookieConn, &withSynCookieConn} { + *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + } + defer withoutSynCookieConn.Close(t) + defer withSynCookieConn.Close(t) + + checkSynAck := func(t *testing.T, conn *testbench.TCPIPv4, expectRetransmit bool) { + // Expect dut connection to have transitioned to SYN-RCVD state. + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK, but got %s", err) + } + + // If the DUT listener is using syn cookies, it will not retransmit SYNACK + got, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)), Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, 2*time.Second) + if expectRetransmit && err != nil { + t.Fatalf("expected retransmitted SYN-ACK, but got %s", err) + } + if !expectRetransmit && err == nil { + t.Fatalf("expected no retransmitted SYN-ACK, but got %s", got) + } + } + + t.Run("without syncookies", func(t *testing.T) { checkSynAck(t, &withoutSynCookieConn, true /*expectRetransmit*/) }) + t.Run("with syncookies", func(t *testing.T) { checkSynAck(t, &withSynCookieConn, false /*expectRetransmit*/) }) +} diff --git a/test/perf/BUILD b/test/perf/BUILD index ed899ac22..75b5003e2 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -1,3 +1,4 @@ +load("//tools:defs.bzl", "more_shards") load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) @@ -35,8 +36,9 @@ syscall_test( ) syscall_test( - 
size = "enormous", + size = "large", debug = False, + shard_count = more_shards, tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", ) @@ -48,7 +50,7 @@ syscall_test( ) syscall_test( - size = "enormous", + size = "large", debug = False, tags = ["nogotsan"], test = "//test/perf/linux:gettid_benchmark", @@ -106,7 +108,7 @@ syscall_test( ) syscall_test( - size = "enormous", + size = "large", debug = False, test = "//test/perf/linux:signal_benchmark", ) @@ -124,9 +126,10 @@ syscall_test( ) syscall_test( - size = "enormous", + size = "large", add_overlay = True, debug = False, + tags = ["nogotsan"], test = "//test/perf/linux:unlink_benchmark", ) diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc index db74cb264..047a034bd 100644 --- a/test/perf/linux/getpid_benchmark.cc +++ b/test/perf/linux/getpid_benchmark.cc @@ -31,6 +31,24 @@ void BM_Getpid(benchmark::State& state) { BENCHMARK(BM_Getpid); +#ifdef __x86_64__ + +#define SYSNO_STR1(x) #x +#define SYSNO_STR(x) SYSNO_STR1(x) + +// BM_GetpidOpt uses the most often pattern of calling system calls: +// mov $SYS_XXX, %eax; syscall. 
+void BM_GetpidOpt(benchmark::State& state) { + for (auto s : state) { + __asm__("movl $" SYSNO_STR(SYS_getpid) ", %%eax\n" + "syscall\n" + : : : "rax", "rcx", "r11"); + } +} + +BENCHMARK(BM_GetpidOpt); +#endif // __x86_64__ + } // namespace } // namespace testing diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc index 7b060c70e..d495f3ddc 100644 --- a/test/perf/linux/write_benchmark.cc +++ b/test/perf/linux/write_benchmark.cc @@ -46,6 +46,18 @@ void BM_Write(benchmark::State& state) { BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime(); +void BM_Append(benchmark::State& state) { + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY | O_APPEND)); + + const char data = 'a'; + for (auto _ : state) { + TEST_CHECK(WriteFd(fd.get(), &data, 1) == 1); + } +} + +BENCHMARK(BM_Append); + } // namespace } // namespace testing diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index a74d6b1c1..39e838582 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -308,8 +308,8 @@ func TestCgroup(t *testing.T) { } } -// TestCgroupParent sets the "CgroupParent" option and checks that the child and parent's -// cgroups are created correctly relative to each other. +// TestCgroupParent sets the "CgroupParent" option and checks that the child and +// parent's cgroups are created correctly relative to each other. func TestCgroupParent(t *testing.T) { ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) @@ -343,15 +343,19 @@ func TestCgroupParent(t *testing.T) { // Finds cgroup for the sandbox's parent process to check that cgroup is // created in the right location relative to the parent. 
cmd := fmt.Sprintf("grep PPid: /proc/%d/status | sed 's/PPid:\\s//'", pid) - ppid, err := exec.Command("bash", "-c", cmd).CombinedOutput() + ppidStr, err := exec.Command("bash", "-c", cmd).CombinedOutput() if err != nil { t.Fatalf("Executing %q: %v", cmd, err) } - cgroups, err := cgroup.LoadPaths(strings.TrimSpace(string(ppid))) + ppid, err := strconv.Atoi(strings.TrimSpace(string(ppidStr))) if err != nil { - t.Fatalf("cgroup.LoadPath(%s): %v", ppid, err) + t.Fatalf("invalid PID (%s): %v", ppidStr, err) } - path := filepath.Join("/sys/fs/cgroup/memory", cgroups["memory"], parent, gid, "cgroup.procs") + cgroups, err := cgroup.NewFromPid(ppid) + if err != nil { + t.Fatalf("cgroup.NewFromPid(%d): %v", ppid, err) + } + path := filepath.Join(cgroups.MakePath("cpuacct"), parent, gid, "cgroup.procs") if err := verifyPid(pid, path); err != nil { t.Errorf("cgroup control %q processes: %v", "memory", err) } diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 829247657..2a0ef2cec 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -4,7 +4,7 @@ load("//tools:defs.bzl", "default_platform", "platforms") def _runner_test_impl(ctx): # Generate a runner binary. - runner = ctx.actions.declare_file("%s-runner" % ctx.label.name) + runner = ctx.actions.declare_file(ctx.label.name) runner_content = "\n".join([ "#!/bin/bash", "set -euf -x -o pipefail", @@ -85,18 +85,9 @@ def _syscall_test( # Add the full_platform and file access in a tag to make it easier to run # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. + tags = list(tags) tags += [full_platform, "file_" + file_access] - # Hash this target into one of 15 buckets. This can be used to - # randomly split targets between different workflows. - hash15 = hash(native.package_name() + name) % 15 - tags.append("hash15:" + str(hash15)) - - # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until - # we figure out how to request ipv4 sockets on Guitar machines. 
- if network == "host": - tags.append("noguitar") - # Disable off-host networking. tags.append("requires-net:loopback") tags.append("requires-net:ipv4") @@ -157,116 +148,82 @@ def syscall_test( if not tags: tags = [] - vfs2_tags = list(tags) - if vfs2: - # Add tag to easily run VFS2 tests with --test_tag_filters=vfs2 - vfs2_tags.append("vfs2") - if fuse: - vfs2_tags.append("fuse") - - else: - # Don't automatically run tests tests not yet passing. - vfs2_tags.append("manual") - vfs2_tags.append("noguitar") - vfs2_tags.append("notap") - - _syscall_test( - test = test, - platform = default_platform, - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = platforms[default_platform] + vfs2_tags, - debug = debug, - vfs2 = True, - fuse = fuse, - **kwargs - ) - if fuse: - # Only generate *_vfs2_fuse target if fuse parameter is enabled. - return - - _syscall_test( - test = test, - platform = "native", - use_tmpfs = False, - add_uds_tree = add_uds_tree, - tags = list(tags), - debug = debug, - **kwargs - ) - - for (platform, platform_tags) in platforms.items(): + if vfs2 and not fuse: + # Generate a vfs1 plain test. Most testing will now be + # biased towards vfs2, with only a single vfs1 case. _syscall_test( test = test, - platform = platform, + platform = default_platform, use_tmpfs = use_tmpfs, add_uds_tree = add_uds_tree, - tags = platform_tags + tags, + tags = tags + platforms[default_platform], debug = debug, + vfs2 = False, **kwargs ) - if add_overlay: + if not fuse: + # Generate a native test if fuse is not required. _syscall_test( test = test, - platform = default_platform, - use_tmpfs = use_tmpfs, + platform = "native", + use_tmpfs = False, add_uds_tree = add_uds_tree, - tags = platforms[default_platform] + tags, + tags = tags, debug = debug, - overlay = True, **kwargs ) - # TODO(gvisor.dev/issue/4407): Remove tags to enable VFS2 overlay tests. 
- overlay_vfs2_tags = list(vfs2_tags) - overlay_vfs2_tags.append("manual") - overlay_vfs2_tags.append("noguitar") - overlay_vfs2_tags.append("notap") + for (platform, platform_tags) in platforms.items(): _syscall_test( test = test, - platform = default_platform, + platform = platform, use_tmpfs = use_tmpfs, add_uds_tree = add_uds_tree, - tags = platforms[default_platform] + overlay_vfs2_tags, + tags = platform_tags + tags, + fuse = fuse, + vfs2 = vfs2, debug = debug, - overlay = True, - vfs2 = True, **kwargs ) - if add_hostinet: + if add_overlay: _syscall_test( test = test, platform = default_platform, use_tmpfs = use_tmpfs, - network = "host", add_uds_tree = add_uds_tree, tags = platforms[default_platform] + tags, debug = debug, + fuse = fuse, + vfs2 = vfs2, + overlay = True, **kwargs ) - - if not use_tmpfs: - # Also test shared gofer access. + if add_hostinet: _syscall_test( test = test, platform = default_platform, use_tmpfs = use_tmpfs, + network = "host", add_uds_tree = add_uds_tree, tags = platforms[default_platform] + tags, debug = debug, - file_access = "shared", + fuse = fuse, + vfs2 = vfs2, **kwargs ) + if not use_tmpfs: + # Also test shared gofer access. _syscall_test( test = test, platform = default_platform, use_tmpfs = use_tmpfs, add_uds_tree = add_uds_tree, - tags = platforms[default_platform] + vfs2_tags, + tags = platforms[default_platform] + tags, debug = debug, file_access = "shared", - vfs2 = True, + fuse = fuse, + vfs2 = vfs2, **kwargs ) diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go index 2ad5f58ef..38e57d62f 100644 --- a/test/runner/gtest/gtest.go +++ b/test/runner/gtest/gtest.go @@ -35,39 +35,6 @@ var ( filterBenchmarkFlag = "--benchmark_filter" ) -// BuildTestArgs builds arguments to be passed to the test binary to execute -// only the test cases in `indices`. 
-func BuildTestArgs(indices []int, testCases []TestCase) []string { - var testFilter, benchFilter string - for _, tci := range indices { - tc := testCases[tci] - if tc.all { - // No argument will make all tests run. - return nil - } - if tc.benchmark { - if len(benchFilter) > 0 { - benchFilter += "|" - } - benchFilter += "^" + tc.Name + "$" - } else { - if len(testFilter) > 0 { - testFilter += ":" - } - testFilter += tc.FullName() - } - } - - var args []string - if len(testFilter) > 0 { - args = append(args, fmt.Sprintf("%s=%s", filterTestFlag, testFilter)) - } - if len(benchFilter) > 0 { - args = append(args, fmt.Sprintf("%s=%s", filterBenchmarkFlag, benchFilter)) - } - return args -} - // TestCase is a single gtest test case. type TestCase struct { // Suite is the suite for this test. @@ -92,6 +59,22 @@ func (tc TestCase) FullName() string { return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) } +// Args returns arguments to be passed when invoking the test. +func (tc TestCase) Args() []string { + if tc.all { + return []string{} // No arguments. + } + if tc.benchmark { + return []string{ + fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), + fmt.Sprintf("%s=", filterTestFlag), + } + } + return []string{ + fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()), + } +} + // ParseTestCases calls a gtest test binary to list its test and returns a // slice with the name and suite of each test. // @@ -107,7 +90,6 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes // We failed to list tests with the given flags. Just // return something that will run the binary with no // flags, which should execute all tests. 
- fmt.Printf("failed to get test list: %v\n", err) return []TestCase{ { Suite: "Default", diff --git a/test/runner/runner.go b/test/runner/runner.go index a8a134fe2..7e8e88ba2 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -26,6 +26,7 @@ import ( "path/filepath" "strings" "syscall" + "testing" "time" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -56,82 +57,13 @@ var ( leakCheck = flag.Bool("leak-check", false, "check for reference leaks") ) -func main() { - flag.Parse() - if flag.NArg() != 1 { - fatalf("test must be provided") - } - - log.SetLevel(log.Info) - if *debug { - log.SetLevel(log.Debug) - } - - if *platform != "native" && *runscPath == "" { - if err := testutil.ConfigureExePath(); err != nil { - panic(err.Error()) - } - *runscPath = specutils.ExePath - } - - // Make sure stdout and stderr are opened with O_APPEND, otherwise logs - // from outside the sandbox can (and will) stomp on logs from inside - // the sandbox. - for _, f := range []*os.File{os.Stdout, os.Stderr} { - flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) - if err != nil { - fatalf("error getting file flags for %v: %v", f, err) - } - if flags&unix.O_APPEND == 0 { - flags |= unix.O_APPEND - if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { - fatalf("error setting file flags for %v: %v", f, err) - } - } - } - - // Resolve the absolute path for the binary. - testBin, err := filepath.Abs(flag.Args()[0]) - if err != nil { - fatalf("Abs(%q) failed: %v", flag.Args()[0], err) - } - - // Get all test cases in each binary. - testCases, err := gtest.ParseTestCases(testBin, true) - if err != nil { - fatalf("ParseTestCases(%q) failed: %v", testBin, err) - } - - // Get subset of tests corresponding to shard. 
- indices, err := testutil.TestIndicesForShard(len(testCases)) - if err != nil { - fatalf("TestsForShard() failed: %v", err) - } - if len(indices) == 0 { - log.Warningf("No tests to run in this shard") - return - } - args := gtest.BuildTestArgs(indices, testCases) - - switch *platform { - case "native": - if err := runTestCaseNative(testBin, args); err != nil { - fatalf(err.Error()) - } - default: - if err := runTestCaseRunsc(testBin, args); err != nil { - fatalf(err.Error()) - } - } -} - // runTestCaseNative runs the test case directly on the host machine. -func runTestCaseNative(testBin string, args []string) error { +func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { // These tests might be running in parallel, so make sure they have a // unique test temp dir. tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") if err != nil { - return fmt.Errorf("could not create temp dir: %v", err) + t.Fatalf("could not create temp dir: %v", err) } defer os.RemoveAll(tmpDir) @@ -152,12 +84,12 @@ func runTestCaseNative(testBin string, args []string) error { } // Remove shard env variables so that the gunit binary does not try to // interpret them. - env = filterEnv(env, "TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS") + env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) if *addUDSTree { socketDir, cleanup, err := uds.CreateSocketTree("/tmp") if err != nil { - return fmt.Errorf("failed to create socket tree: %v", err) + t.Fatalf("failed to create socket tree: %v", err) } defer cleanup() @@ -167,7 +99,7 @@ func runTestCaseNative(testBin string, args []string) error { env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) } - cmd := exec.Command(testBin, args...) + cmd := exec.Command(testBin, tc.Args()...) 
cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -183,9 +115,8 @@ func runTestCaseNative(testBin string, args []string) error { if err := cmd.Run(); err != nil { ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) - return fmt.Errorf("test exited with status %d, want 0", ws.ExitStatus()) + t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) } - return nil } // runRunsc runs spec in runsc in a standard test configuration. @@ -193,7 +124,7 @@ func runTestCaseNative(testBin string, args []string) error { // runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR. // // Returns an error if the sandboxed application exits non-zero. -func runRunsc(spec *specs.Spec) error { +func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { bundleDir, cleanup, err := testutil.SetupBundleDir(spec) if err != nil { return fmt.Errorf("SetupBundleDir failed: %v", err) @@ -206,8 +137,9 @@ func runRunsc(spec *specs.Spec) error { } defer cleanup() + name := tc.FullName() id := testutil.RandomContainerID() - log.Infof("Running test in container %q", id) + log.Infof("Running test %q in container %q", name, id) specutils.LogSpec(spec) args := []string{ @@ -243,8 +175,13 @@ func runRunsc(spec *specs.Spec) error { args = append(args, "-ref-leak-mode=log-names") } - testLogDir := os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR") - if len(testLogDir) > 0 { + testLogDir := "" + if undeclaredOutputsDir, ok := unix.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { + // Create log directory dedicated for this test. 
+ testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) + if err := os.MkdirAll(testLogDir, 0755); err != nil { + return fmt.Errorf("could not create test dir: %v", err) + } debugLogDir, err := ioutil.TempDir(testLogDir, "runsc") if err != nil { return fmt.Errorf("could not create temp dir: %v", err) @@ -252,6 +189,7 @@ func runRunsc(spec *specs.Spec) error { debugLogDir += "/" log.Infof("runsc logs: %s", debugLogDir) args = append(args, "-debug-log", debugLogDir) + args = append(args, "-coverage-report", debugLogDir) // Default -log sends messages to stderr which makes reading the test log // difficult. Instead, drop them when debug log is enabled given it's a @@ -289,7 +227,7 @@ func runRunsc(spec *specs.Spec) error { if !ok { return } - log.Warningf("Got signal: %v", s) + log.Warningf("%s: Got signal: %v", name, s) done := make(chan bool, 1) dArgs := append([]string{}, args...) dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) @@ -322,7 +260,7 @@ func runRunsc(spec *specs.Spec) error { if err == nil && len(testLogDir) > 0 { // If the test passed, then we erase the log directory. This speeds up // uploading logs in continuous integration & saves on disk space. - _ = os.RemoveAll(testLogDir) + os.RemoveAll(testLogDir) } return err @@ -377,10 +315,10 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { } // runsTestCaseRunsc runs the test case in runsc. -func runTestCaseRunsc(testBin string, args []string) error { +func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Run a new container with the test executable and filter for the // given test suite and name. - spec := testutil.NewSpecWithArgs(append([]string{testBin}, args...)...) + spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...) // Mark the root as writeable, as some tests attempt to // write to the rootfs, and expect EACCES, not EROFS. 
@@ -406,12 +344,12 @@ func runTestCaseRunsc(testBin string, args []string) error { // users, so make sure it is world-accessible. tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") if err != nil { - return fmt.Errorf("could not create temp dir: %v", err) + t.Fatalf("could not create temp dir: %v", err) } defer os.RemoveAll(tmpDir) if err := os.Chmod(tmpDir, 0777); err != nil { - return fmt.Errorf("could not chmod temp dir: %v", err) + t.Fatalf("could not chmod temp dir: %v", err) } // "/tmp" is not replaced with a tmpfs mount inside the sandbox @@ -431,12 +369,13 @@ func runTestCaseRunsc(testBin string, args []string) error { // Set environment variables that indicate we are running in gVisor with // the given platform, network, and filesystem stack. - env := []string{"TEST_ON_GVISOR=" + *platform, "GVISOR_NETWORK=" + *network} - env = append(env, os.Environ()...) - const vfsVar = "GVISOR_VFS" + platformVar := "TEST_ON_GVISOR" + networkVar := "GVISOR_NETWORK" + env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network) + vfsVar := "GVISOR_VFS" if *vfs2 { env = append(env, vfsVar+"=VFS2") - const fuseVar = "FUSE_ENABLED" + fuseVar := "FUSE_ENABLED" if *fuse { env = append(env, fuseVar+"=TRUE") } else { @@ -448,11 +387,11 @@ func runTestCaseRunsc(testBin string, args []string) error { // Remove shard env variables so that the gunit binary does not try to // interpret them. - env = filterEnv(env, "TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS") + env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to // be backed by tmpfs. 
- env = filterEnv(env, "TEST_TMPDIR") + env = filterEnv(env, []string{"TEST_TMPDIR"}) env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir)) spec.Process.Env = env @@ -460,19 +399,18 @@ func runTestCaseRunsc(testBin string, args []string) error { if *addUDSTree { cleanup, err := setupUDSTree(spec) if err != nil { - return fmt.Errorf("error creating UDS tree: %v", err) + t.Fatalf("error creating UDS tree: %v", err) } defer cleanup() } - if err := runRunsc(spec); err != nil { - return fmt.Errorf("test failed with error %v, want nil", err) + if err := runRunsc(tc, spec); err != nil { + t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err) } - return nil } // filterEnv returns an environment with the excluded variables removed. -func filterEnv(env []string, exclude ...string) []string { +func filterEnv(env, exclude []string) []string { var out []string for _, kv := range env { ok := true @@ -493,3 +431,82 @@ func fatalf(s string, args ...interface{}) { fmt.Fprintf(os.Stderr, s+"\n", args...) os.Exit(1) } + +func matchString(a, b string) (bool, error) { + return a == b, nil +} + +func main() { + flag.Parse() + if flag.NArg() != 1 { + fatalf("test must be provided") + } + testBin := flag.Args()[0] // Only argument. + + log.SetLevel(log.Info) + if *debug { + log.SetLevel(log.Debug) + } + + if *platform != "native" && *runscPath == "" { + if err := testutil.ConfigureExePath(); err != nil { + panic(err.Error()) + } + *runscPath = specutils.ExePath + } + + // Make sure stdout and stderr are opened with O_APPEND, otherwise logs + // from outside the sandbox can (and will) stomp on logs from inside + // the sandbox. 
+ for _, f := range []*os.File{os.Stdout, os.Stderr} { + flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) + if err != nil { + fatalf("error getting file flags for %v: %v", f, err) + } + if flags&unix.O_APPEND == 0 { + flags |= unix.O_APPEND + if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { + fatalf("error setting file flags for %v: %v", f, err) + } + } + } + + // Get all test cases in each binary. + testCases, err := gtest.ParseTestCases(testBin, true) + if err != nil { + fatalf("ParseTestCases(%q) failed: %v", testBin, err) + } + + // Get subset of tests corresponding to shard. + indices, err := testutil.TestIndicesForShard(len(testCases)) + if err != nil { + fatalf("TestsForShard() failed: %v", err) + } + + // Resolve the absolute path for the binary. + testBin, err = filepath.Abs(testBin) + if err != nil { + fatalf("Abs() failed: %v", err) + } + + // Run the tests. + var tests []testing.InternalTest + for _, tci := range indices { + // Capture tc. + tc := testCases[tci] + tests = append(tests, testing.InternalTest{ + Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name), + F: func(t *testing.T) { + if *platform == "native" { + // Run the test case on host. + runTestCaseNative(testBin, tc, t) + } else { + // Run the test case in runsc. + runTestCaseRunsc(testBin, tc, t) + } + }, + }) + } + + testing.Main(matchString, tests, nil, nil) +} diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl index 702522d86..2550b61a3 100644 --- a/test/runtimes/defs.bzl +++ b/test/runtimes/defs.bzl @@ -75,7 +75,6 @@ def runtime_test(name, **kwargs): "local", "manual", ], - size = "enormous", **kwargs ) diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index ef299799e..0435f61a2 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -3,6 +3,8 @@ load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) +# Please keep syscall tests ordered alphabetically by name. 
+ syscall_test( test = "//test/syscalls/linux:32bit_test", ) @@ -56,17 +58,7 @@ syscall_test( ) syscall_test( - test = "//test/syscalls/linux:socket_test", -) - -syscall_test( - test = "//test/syscalls/linux:socket_capability_test", -) - -syscall_test( - size = "large", - shard_count = most_shards, - test = "//test/syscalls/linux:socket_stress_test", + test = "//test/syscalls/linux:cgroup_test", ) syscall_test( @@ -244,6 +236,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:verity_ioctl_test", +) + +syscall_test( test = "//test/syscalls/linux:iptables_test", ) @@ -307,6 +303,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:mlock_test", +) + +syscall_test( size = "medium", shard_count = more_shards, test = "//test/syscalls/linux:mmap_test", @@ -604,6 +604,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:socket_capability_test", +) + +syscall_test( size = "medium", test = "//test/syscalls/linux:socket_domain_non_blocking_test", ) @@ -772,8 +776,17 @@ syscall_test( ) syscall_test( - # NOTE(b/116636318): Large sendmsg may stall a long time. - size = "enormous", + size = "large", + shard_count = most_shards, + test = "//test/syscalls/linux:socket_stress_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_test", +) + +syscall_test( + flaky = 1, # NOTE(b/116636318): Large sendmsg may stall a long time. shard_count = more_shards, test = "//test/syscalls/linux:socket_unix_dgram_local_test", ) @@ -791,8 +804,7 @@ syscall_test( ) syscall_test( - # NOTE(b/116636318): Large sendmsg may stall a long time. - size = "enormous", + flaky = 1, # NOTE(b/116636318): Large sendmsg may stall a long time. 
shard_count = more_shards, test = "//test/syscalls/linux:socket_unix_seqpacket_local_test", ) @@ -995,3 +1007,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:processes_test", ) + +syscall_test( + test = "//test/syscalls/linux:verity_mount_test", +) diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc index 3c825477c..cbf1b4f05 100644 --- a/test/syscalls/linux/32bit.cc +++ b/test/syscalls/linux/32bit.cc @@ -22,15 +22,13 @@ #include "test/util/posix_error.h" #include "test/util/test_util.h" -#ifndef __x86_64__ -#error "This test is x86-64 specific." -#endif - namespace gvisor { namespace testing { namespace { +#ifdef __x86_64__ + constexpr char kInt3 = '\xcc'; constexpr char kInt80[2] = {'\xcd', '\x80'}; constexpr char kSyscall[2] = {'\x0f', '\x05'}; @@ -242,6 +240,8 @@ TEST(Call32Bit, Disallowed) { } } +#endif + } // namespace } // namespace testing diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 043ada583..94a582256 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -212,10 +212,7 @@ cc_binary( cc_binary( name = "32bit_test", testonly = 1, - srcs = select_arch( - amd64 = ["32bit.cc"], - arm64 = [], - ), + srcs = ["32bit.cc"], linkstatic = 1, deps = [ "@com_google_absl//absl/base:core_headers", @@ -1014,6 +1011,22 @@ cc_binary( ], ) +cc_binary( + name = "verity_ioctl_test", + testonly = 1, + srcs = ["verity_ioctl.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + gtest, + "//test/util:fs_util", + "//test/util:mount_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + cc_library( name = "iptables_types", testonly = 1, @@ -1304,6 +1317,20 @@ cc_binary( ) cc_binary( + name = "verity_mount_test", + testonly = 1, + srcs = ["verity_mount.cc"], + linkstatic = 1, + deps = [ + gtest, + "//test/util:capability_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = 
"mremap_test", testonly = 1, srcs = ["mremap.cc"], @@ -1699,6 +1726,7 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "//test/util:mount_util", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -4205,3 +4233,25 @@ cc_binary( "//test/util:test_util", ], ) + +cc_binary( + name = "cgroup_test", + testonly = 1, + srcs = ["cgroup.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:cgroup_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:mount_util", + "@com_google_absl//absl/strings", + gtest, + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc index f65a14fb8..fe560cfc5 100644 --- a/test/syscalls/linux/accept_bind.cc +++ b/test/syscalls/linux/accept_bind.cc @@ -67,6 +67,42 @@ TEST_P(AllSocketPairTest, ListenDecreaseBacklog) { SyscallSucceeds()); } +TEST_P(AllSocketPairTest, ListenBacklogSizes) { + DisableSave ds; + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + int type; + socklen_t typelen = sizeof(type); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen), + SyscallSucceeds()); + + std::array<int, 3> backlogs = {-1, 0, 1}; + for (auto& backlog : backlogs) { + ASSERT_THAT(listen(sockets->first_fd(), backlog), SyscallSucceeds()); + + int expected_accepts = backlog; + if (backlog < 0) { + expected_accepts = 1024; + } + for (int i = 0; i < expected_accepts; i++) { + SCOPED_TRACE(absl::StrCat("i=", i)); + // Connect to the listening socket. 
+ const FileDescriptor client = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0)); + ASSERT_THAT(connect(client.get(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + const FileDescriptor accepted = ASSERT_NO_ERRNO_AND_VALUE( + Accept(sockets->first_fd(), nullptr, nullptr)); + } + } +} + TEST_P(AllSocketPairTest, ListenWithoutBind) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(EINVAL)); @@ -285,8 +321,7 @@ TEST_P(AllSocketPairTest, AcceptValidAddrLen) { struct sockaddr_un addr = {}; socklen_t addr_len = sizeof(addr); ASSERT_THAT( - accepted = accept(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(&addr), &addr_len), + accepted = accept(sockets->first_fd(), AsSockAddr(&addr), &addr_len), SyscallSucceeds()); ASSERT_THAT(close(accepted), SyscallSucceeds()); } @@ -307,8 +342,7 @@ TEST_P(AllSocketPairTest, AcceptNegativeAddrLen) { // With a negative addr_len, accept returns EINVAL, struct sockaddr_un addr = {}; socklen_t addr_len = -1; - ASSERT_THAT(accept(sockets->first_fd(), - reinterpret_cast<struct sockaddr*>(&addr), &addr_len), + ASSERT_THAT(accept(sockets->first_fd(), AsSockAddr(&addr), &addr_len), SyscallFailsWithErrno(EINVAL)); } @@ -499,10 +533,9 @@ TEST_P(AllSocketPairTest, UnboundSenderAddr) { struct sockaddr_storage addr; socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&addr), &addr_len), - SyscallSucceedsWithValue(sizeof(i))); + ASSERT_THAT(RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, + AsSockAddr(&addr), &addr_len), + SyscallSucceedsWithValue(sizeof(i))); EXPECT_EQ(addr_len, 0); } @@ -534,10 +567,9 @@ TEST_P(AllSocketPairTest, BoundSenderAddr) { struct sockaddr_storage addr; socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&addr), 
&addr_len), - SyscallSucceedsWithValue(sizeof(i))); + ASSERT_THAT(RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, + AsSockAddr(&addr), &addr_len), + SyscallSucceedsWithValue(sizeof(i))); EXPECT_EQ(addr_len, sockets->second_addr_len()); EXPECT_EQ( memcmp(&addr, sockets->second_addr(), @@ -573,10 +605,9 @@ TEST_P(AllSocketPairTest, BindAfterConnectSenderAddr) { struct sockaddr_storage addr; socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&addr), &addr_len), - SyscallSucceedsWithValue(sizeof(i))); + ASSERT_THAT(RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, + AsSockAddr(&addr), &addr_len), + SyscallSucceedsWithValue(sizeof(i))); EXPECT_EQ(addr_len, sockets->second_addr_len()); EXPECT_EQ( memcmp(&addr, sockets->second_addr(), @@ -612,10 +643,9 @@ TEST_P(AllSocketPairTest, BindAfterAcceptSenderAddr) { struct sockaddr_storage addr; socklen_t addr_len = sizeof(addr); - ASSERT_THAT( - RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&addr), &addr_len), - SyscallSucceedsWithValue(sizeof(i))); + ASSERT_THAT(RetryEINTR(recvfrom)(accepted_fd.get(), &i, sizeof(i), 0, + AsSockAddr(&addr), &addr_len), + SyscallSucceedsWithValue(sizeof(i))); EXPECT_EQ(addr_len, sockets->second_addr_len()); EXPECT_EQ( memcmp(&addr, sockets->second_addr(), diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc index 940c97285..cd0704334 100644 --- a/test/syscalls/linux/alarm.cc +++ b/test/syscalls/linux/alarm.cc @@ -36,7 +36,7 @@ void do_nothing_handler(int sig, siginfo_t* siginfo, void* arg) {} // No random save as the test relies on alarm timing. Cooperative save tests // already cover the save between alarm and read. 
-TEST(AlarmTest, Interrupt_NoRandomSave) { +TEST(AlarmTest, Interrupt) { int pipe_fds[2]; ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds()); @@ -71,7 +71,7 @@ void inc_alarms_handler(int sig, siginfo_t* siginfo, void* arg) { // No random save as the test relies on alarm timing. Cooperative save tests // already cover the save between alarm and read. -TEST(AlarmTest, Restart_NoRandomSave) { +TEST(AlarmTest, Restart) { alarms_received = 0; int pipe_fds[2]; @@ -114,7 +114,7 @@ TEST(AlarmTest, Restart_NoRandomSave) { // No random save as the test relies on alarm timing. Cooperative save tests // already cover the save between alarm and pause. -TEST(AlarmTest, SaSiginfo_NoRandomSave) { +TEST(AlarmTest, SaSiginfo) { // Use a signal handler that interrupts but does nothing rather than using the // default terminate action. struct sigaction sa; @@ -134,7 +134,7 @@ TEST(AlarmTest, SaSiginfo_NoRandomSave) { // No random save as the test relies on alarm timing. Cooperative save tests // already cover the save between alarm and pause. -TEST(AlarmTest, SaInterrupt_NoRandomSave) { +TEST(AlarmTest, SaInterrupt) { // Use a signal handler that interrupts but does nothing rather than using the // default terminate action. struct sigaction sa; diff --git a/test/syscalls/linux/cgroup.cc b/test/syscalls/linux/cgroup.cc new file mode 100644 index 000000000..70ad5868f --- /dev/null +++ b/test/syscalls/linux/cgroup.cc @@ -0,0 +1,504 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// All tests in this file rely on being about to mount and unmount cgroupfs, +// which isn't expected to work, or be safe on a general linux system. + +#include <limits.h> +#include <sys/mount.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_split.h" +#include "test/util/capability_util.h" +#include "test/util/cgroup_util.h" +#include "test/util/mount_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +using ::testing::_; +using ::testing::Contains; +using ::testing::Ge; +using ::testing::Gt; +using ::testing::Key; +using ::testing::Not; + +std::vector<std::string> known_controllers = { + "cpu", "cpuset", "cpuacct", "job", "memory", +}; + +bool CgroupsAvailable() { + return IsRunningOnGvisor() && !IsRunningWithVFS1() && + TEST_CHECK_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)); +} + +TEST(Cgroup, MountSucceeds) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + EXPECT_NO_ERRNO(c.ContainsCallingProcess()); +} + +TEST(Cgroup, SeparateMounts) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + + for (const auto& ctl : known_controllers) { + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(ctl)); + EXPECT_NO_ERRNO(c.ContainsCallingProcess()); + } +} + +TEST(Cgroup, AllControllersImplicit) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + + absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + for (const auto& ctl : 
known_controllers) { + EXPECT_TRUE(cgroups_entries.contains(ctl)) + << absl::StreamFormat("ctl=%s", ctl); + } + EXPECT_EQ(cgroups_entries.size(), known_controllers.size()); +} + +TEST(Cgroup, AllControllersExplicit) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("all")); + + absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + for (const auto& ctl : known_controllers) { + EXPECT_TRUE(cgroups_entries.contains(ctl)) + << absl::StreamFormat("ctl=%s", ctl); + } + EXPECT_EQ(cgroups_entries.size(), known_controllers.size()); +} + +TEST(Cgroup, ProcsAndTasks) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + absl::flat_hash_set<pid_t> pids = ASSERT_NO_ERRNO_AND_VALUE(c.Procs()); + absl::flat_hash_set<pid_t> tids = ASSERT_NO_ERRNO_AND_VALUE(c.Tasks()); + + EXPECT_GE(tids.size(), pids.size()) << "Found more processes than threads"; + + // Pids should be a strict subset of tids. + for (auto it = pids.begin(); it != pids.end(); ++it) { + EXPECT_TRUE(tids.contains(*it)) + << absl::StreamFormat("Have pid %d, but no such tid", *it); + } +} + +TEST(Cgroup, ControllersMustBeInUniqueHierarchy) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + // Hierarchy #1: all controllers. + Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + // Hierarchy #2: memory. + // + // This should conflict since memory is already in hierarchy #1, and the two + // hierarchies have different sets of controllers, so this mount can't be a + // view into hierarchy #1. 
+ EXPECT_THAT(m.MountCgroupfs("memory"), PosixErrorIs(EBUSY, _)) + << "Memory controller mounted on two hierarchies"; + EXPECT_THAT(m.MountCgroupfs("cpu"), PosixErrorIs(EBUSY, _)) + << "CPU controller mounted on two hierarchies"; +} + +TEST(Cgroup, UnmountFreesControllers) { + SKIP_IF(!CgroupsAvailable()); + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + // All controllers are now attached to all's hierarchy. Attempting new mount + // with any individual controller should fail. + EXPECT_THAT(m.MountCgroupfs("memory"), PosixErrorIs(EBUSY, _)) + << "Memory controller mounted on two hierarchies"; + + // Unmount the "all" hierarchy. This should enable any controller to be + // mounted on a new hierarchy again. + ASSERT_NO_ERRNO(m.Unmount(all)); + EXPECT_NO_ERRNO(m.MountCgroupfs("memory")); + EXPECT_NO_ERRNO(m.MountCgroupfs("cpu")); +} + +TEST(Cgroup, OnlyContainsControllerSpecificFiles) { + SKIP_IF(!CgroupsAvailable()); + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup mem = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + EXPECT_THAT(Exists(mem.Relpath("memory.usage_in_bytes")), + IsPosixErrorOkAndHolds(true)); + // CPU files shouldn't exist in memory cgroups. + EXPECT_THAT(Exists(mem.Relpath("cpu.cfs_period_us")), + IsPosixErrorOkAndHolds(false)); + EXPECT_THAT(Exists(mem.Relpath("cpu.cfs_quota_us")), + IsPosixErrorOkAndHolds(false)); + EXPECT_THAT(Exists(mem.Relpath("cpu.shares")), IsPosixErrorOkAndHolds(false)); + + Cgroup cpu = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + EXPECT_THAT(Exists(cpu.Relpath("cpu.cfs_period_us")), + IsPosixErrorOkAndHolds(true)); + EXPECT_THAT(Exists(cpu.Relpath("cpu.cfs_quota_us")), + IsPosixErrorOkAndHolds(true)); + EXPECT_THAT(Exists(cpu.Relpath("cpu.shares")), IsPosixErrorOkAndHolds(true)); + // Memory files shouldn't exist in cpu cgroups. 
+ EXPECT_THAT(Exists(cpu.Relpath("memory.usage_in_bytes")), + IsPosixErrorOkAndHolds(false)); +} + +TEST(Cgroup, InvalidController) { + SKIP_IF(!CgroupsAvailable()); + + TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string mopts = "this-controller-is-invalid"; + EXPECT_THAT( + mount("none", mountpoint.path().c_str(), "cgroup", 0, mopts.c_str()), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(Cgroup, MoptAllMustBeExclusive) { + SKIP_IF(!CgroupsAvailable()); + + TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string mopts = "all,cpu"; + EXPECT_THAT( + mount("none", mountpoint.path().c_str(), "cgroup", 0, mopts.c_str()), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(MemoryCgroup, MemoryUsageInBytes) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + EXPECT_THAT(c.ReadIntegerControlFile("memory.usage_in_bytes"), + IsPosixErrorOkAndHolds(Gt(0))); +} + +TEST(CPUCgroup, ControlFilesHaveDefaultValues) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + EXPECT_THAT(c.ReadIntegerControlFile("cpu.cfs_quota_us"), + IsPosixErrorOkAndHolds(-1)); + EXPECT_THAT(c.ReadIntegerControlFile("cpu.cfs_period_us"), + IsPosixErrorOkAndHolds(100000)); + EXPECT_THAT(c.ReadIntegerControlFile("cpu.shares"), + IsPosixErrorOkAndHolds(1024)); +} + +TEST(CPUAcctCgroup, CPUAcctUsage) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct")); + + const int64_t usage = + ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage")); + const int64_t usage_user = + ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage_user")); + const int64_t usage_sys = + 
ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage_sys")); + + EXPECT_GE(usage, 0); + EXPECT_GE(usage_user, 0); + EXPECT_GE(usage_sys, 0); + + EXPECT_GE(usage_user + usage_sys, usage); +} + +TEST(CPUAcctCgroup, CPUAcctStat) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct")); + + std::string stat = + ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuacct.stat")); + + // We're expecting the contents of "cpuacct.stat" to look similar to this: + // + // user 377986 + // system 220662 + + std::vector<absl::string_view> lines = + absl::StrSplit(stat, '\n', absl::SkipEmpty()); + ASSERT_EQ(lines.size(), 2); + + std::vector<absl::string_view> user_tokens = + StrSplit(lines[0], absl::ByChar(' ')); + EXPECT_EQ(user_tokens[0], "user"); + EXPECT_THAT(Atoi<int64_t>(user_tokens[1]), IsPosixErrorOkAndHolds(Ge(0))); + + std::vector<absl::string_view> sys_tokens = + StrSplit(lines[1], absl::ByChar(' ')); + EXPECT_EQ(sys_tokens[0], "system"); + EXPECT_THAT(Atoi<int64_t>(sys_tokens[1]), IsPosixErrorOkAndHolds(Ge(0))); +} + +// WriteAndVerifyControlValue attempts to write val to a cgroup file at path, +// and verify the value by reading it afterwards. 
+PosixError WriteAndVerifyControlValue(const Cgroup& c, std::string_view path, + int64_t val) { + RETURN_IF_ERRNO(c.WriteIntegerControlFile(path, val)); + ASSIGN_OR_RETURN_ERRNO(int64_t newval, c.ReadIntegerControlFile(path)); + if (newval != val) { + return PosixError( + EINVAL, + absl::StrFormat( + "Unexpected value for control file '%s': expected %d, got %d", path, + val, newval)); + } + return NoError(); +} + +TEST(JobCgroup, ReadWriteRead) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("job")); + + EXPECT_THAT(c.ReadIntegerControlFile("job.id"), IsPosixErrorOkAndHolds(0)); + EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", 1234)); + EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", -1)); + EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", LLONG_MIN)); + EXPECT_NO_ERRNO(WriteAndVerifyControlValue(c, "job.id", LLONG_MAX)); +} + +TEST(ProcCgroups, Empty) { + SKIP_IF(!CgroupsAvailable()); + + absl::flat_hash_map<std::string, CgroupsEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + // No cgroups mounted yet, we should have no entries. + EXPECT_TRUE(entries.empty()); +} + +TEST(ProcCgroups, ProcCgroupsEntries) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + + Cgroup mem = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + absl::flat_hash_map<std::string, CgroupsEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + EXPECT_EQ(entries.size(), 1); + ASSERT_TRUE(entries.contains("memory")); + CgroupsEntry mem_e = entries["memory"]; + EXPECT_EQ(mem_e.subsys_name, "memory"); + EXPECT_GE(mem_e.hierarchy, 1); + // Expect a single root cgroup. + EXPECT_EQ(mem_e.num_cgroups, 1); + // Cgroups are currently always enabled when mounted. + EXPECT_TRUE(mem_e.enabled); + + // Add a second cgroup, and check for new entry. 
+ + Cgroup cpu = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + EXPECT_EQ(entries.size(), 2); + EXPECT_TRUE(entries.contains("memory")); // Still have memory entry. + ASSERT_TRUE(entries.contains("cpu")); + CgroupsEntry cpu_e = entries["cpu"]; + EXPECT_EQ(cpu_e.subsys_name, "cpu"); + EXPECT_GE(cpu_e.hierarchy, 1); + EXPECT_EQ(cpu_e.num_cgroups, 1); + EXPECT_TRUE(cpu_e.enabled); + + // Separate hierarchies, since controllers were mounted separately. + EXPECT_NE(mem_e.hierarchy, cpu_e.hierarchy); +} + +TEST(ProcCgroups, UnmountRemovesEntries) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup cg = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu,memory")); + absl::flat_hash_map<std::string, CgroupsEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + EXPECT_EQ(entries.size(), 2); + + ASSERT_NO_ERRNO(m.Unmount(cg)); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + EXPECT_TRUE(entries.empty()); +} + +TEST(ProcPIDCgroup, Empty) { + SKIP_IF(!CgroupsAvailable()); + + absl::flat_hash_map<std::string, PIDCgroupEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_TRUE(entries.empty()); +} + +TEST(ProcPIDCgroup, Entries) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + + absl::flat_hash_map<std::string, PIDCgroupEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_EQ(entries.size(), 1); + PIDCgroupEntry mem_e = entries["memory"]; + EXPECT_GE(mem_e.hierarchy, 1); + EXPECT_EQ(mem_e.controllers, "memory"); + EXPECT_EQ(mem_e.path, "/"); + + Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_EQ(entries.size(), 2); + EXPECT_TRUE(entries.contains("memory")); 
// Still have memory entry. + PIDCgroupEntry cpu_e = entries["cpu"]; + EXPECT_GE(cpu_e.hierarchy, 1); + EXPECT_EQ(cpu_e.controllers, "cpu"); + EXPECT_EQ(cpu_e.path, "/"); + + // Separate hierarchies, since controllers were mounted separately. + EXPECT_NE(mem_e.hierarchy, cpu_e.hierarchy); +} + +TEST(ProcPIDCgroup, UnmountRemovesEntries) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + + absl::flat_hash_map<std::string, PIDCgroupEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_GT(entries.size(), 0); + + ASSERT_NO_ERRNO(m.Unmount(all)); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_TRUE(entries.empty()); +} + +TEST(ProcCgroup, PIDCgroupMatchesCgroups) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + + absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + absl::flat_hash_map<std::string, PIDCgroupEntry> pid_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + + CgroupsEntry cgroup_mem = cgroups_entries["memory"]; + PIDCgroupEntry pid_mem = pid_entries["memory"]; + + EXPECT_EQ(cgroup_mem.hierarchy, pid_mem.hierarchy); + + CgroupsEntry cgroup_cpu = cgroups_entries["cpu"]; + PIDCgroupEntry pid_cpu = pid_entries["cpu"]; + + EXPECT_EQ(cgroup_cpu.hierarchy, pid_cpu.hierarchy); + EXPECT_NE(cgroup_mem.hierarchy, cgroup_cpu.hierarchy); + EXPECT_NE(pid_mem.hierarchy, pid_cpu.hierarchy); +} + +TEST(ProcCgroup, MultiControllerHierarchy) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory,cpu")); + + 
absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + + CgroupsEntry mem_e = cgroups_entries["memory"]; + CgroupsEntry cpu_e = cgroups_entries["cpu"]; + + // Both controllers should have the same hierarchy ID. + EXPECT_EQ(mem_e.hierarchy, cpu_e.hierarchy); + + absl::flat_hash_map<std::string, PIDCgroupEntry> pid_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + + // Expecting an entry listing both controllers, that matches the previous + // hierarchy ID. Note that the controllers are listed in alphabetical order. + PIDCgroupEntry pid_e = pid_entries["cpu,memory"]; + EXPECT_EQ(pid_e.hierarchy, mem_e.hierarchy); +} + +TEST(ProcCgroup, ProcfsReportsCgroupfsMountOptions) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + // Hierarchy with multiple controllers. + Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory,cpu")); + // Hierarchy with a single controller. 
+ Cgroup c2 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct")); + + const std::vector<ProcMountsEntry> mounts = + ASSERT_NO_ERRNO_AND_VALUE(ProcSelfMountsEntries()); + + for (auto const& e : mounts) { + if (e.mount_point == c1.Path()) { + auto mopts = ParseMountOptions(e.mount_opts); + EXPECT_THAT(mopts, Contains(Key("memory"))); + EXPECT_THAT(mopts, Contains(Key("cpu"))); + EXPECT_THAT(mopts, Not(Contains(Key("cpuacct")))); + } + + if (e.mount_point == c2.Path()) { + auto mopts = ParseMountOptions(e.mount_opts); + EXPECT_THAT(mopts, Contains(Key("cpuacct"))); + EXPECT_THAT(mopts, Not(Contains(Key("cpu")))); + EXPECT_THAT(mopts, Not(Contains(Key("memory")))); + } + } + + const std::vector<ProcMountInfoEntry> mountinfo = + ASSERT_NO_ERRNO_AND_VALUE(ProcSelfMountInfoEntries()); + + for (auto const& e : mountinfo) { + if (e.mount_point == c1.Path()) { + auto mopts = ParseMountOptions(e.super_opts); + EXPECT_THAT(mopts, Contains(Key("memory"))); + EXPECT_THAT(mopts, Contains(Key("cpu"))); + EXPECT_THAT(mopts, Not(Contains(Key("cpuacct")))); + } + + if (e.mount_point == c2.Path()) { + auto mopts = ParseMountOptions(e.super_opts); + EXPECT_THAT(mopts, Contains(Key("cpuacct"))); + EXPECT_THAT(mopts, Not(Contains(Key("cpu")))); + EXPECT_THAT(mopts, Not(Contains(Key("memory")))); + } + } +} + +} // namespace +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/chdir.cc b/test/syscalls/linux/chdir.cc index 3182c228b..3c64b9eab 100644 --- a/test/syscalls/linux/chdir.cc +++ b/test/syscalls/linux/chdir.cc @@ -41,8 +41,8 @@ TEST(ChdirTest, Success) { TEST(ChdirTest, PermissionDenied) { // Drop capabilities that allow us to override directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc index 8233df0f8..dd82c5fb1 100644 --- a/test/syscalls/linux/chmod.cc +++ b/test/syscalls/linux/chmod.cc @@ -33,7 +33,7 @@ namespace { TEST(ChmodTest, ChmodFileSucceeds) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -43,8 +43,8 @@ TEST(ChmodTest, ChmodFileSucceeds) { TEST(ChmodTest, ChmodDirSucceeds) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const std::string fileInDir = NewTempAbsPathInDir(dir.path()); @@ -53,9 +53,9 @@ TEST(ChmodTest, ChmodDirSucceeds) { EXPECT_THAT(open(fileInDir.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); } -TEST(ChmodTest, FchmodFileSucceeds_NoRandomSave) { +TEST(ChmodTest, FchmodFileSucceeds) { // Drop capabilities that allow us to file directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); int fd; @@ -70,10 +70,10 @@ TEST(ChmodTest, FchmodFileSucceeds_NoRandomSave) { EXPECT_THAT(open(file.path().c_str(), O_RDWR), SyscallFailsWithErrno(EACCES)); } -TEST(ChmodTest, FchmodDirSucceeds_NoRandomSave) { +TEST(ChmodTest, FchmodDirSucceeds) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); int fd; @@ -118,7 +118,7 @@ TEST(ChmodTest, FchmodDirWithOpath) { TEST(ChmodTest, FchmodatWithOpath) { SKIP_IF(IsRunningWithVFS1()); // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -140,7 +140,7 @@ TEST(ChmodTest, FchmodatNotDir) { TEST(ChmodTest, FchmodatFileAbsolutePath) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -150,8 +150,8 @@ TEST(ChmodTest, FchmodatFileAbsolutePath) { TEST(ChmodTest, FchmodatDirAbsolutePath) { // Drop capabilities that allow us to override file and directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -167,7 +167,7 @@ TEST(ChmodTest, FchmodatDirAbsolutePath) { TEST(ChmodTest, FchmodatFile) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -188,8 +188,8 @@ TEST(ChmodTest, FchmodatFile) { TEST(ChmodTest, FchmodatDir) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -212,7 +212,7 @@ TEST(ChmodTest, FchmodatDir) { SyscallFailsWithErrno(EACCES)); } -TEST(ChmodTest, ChmodDowngradeWritability_NoRandomSave) { +TEST(ChmodTest, ChmodDowngradeWritability) { auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); int fd; @@ -227,8 +227,8 @@ TEST(ChmodTest, ChmodDowngradeWritability_NoRandomSave) { TEST(ChmodTest, ChmodFileToNoPermissionsSucceeds) { // Drop capabilities that allow us to override file permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); @@ -238,7 +238,7 @@ TEST(ChmodTest, ChmodFileToNoPermissionsSucceeds) { SyscallFailsWithErrno(EACCES)); } -TEST(ChmodTest, FchmodDowngradeWritability_NoRandomSave) { +TEST(ChmodTest, FchmodDowngradeWritability) { auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); int fd; @@ -252,10 +252,10 @@ TEST(ChmodTest, FchmodDowngradeWritability_NoRandomSave) { EXPECT_THAT(close(fd), SyscallSucceeds()); } -TEST(ChmodTest, FchmodFileToNoPermissionsSucceeds_NoRandomSave) { +TEST(ChmodTest, FchmodFileToNoPermissionsSucceeds) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666)); diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc index ff0d39343..b0c1b6f4a 100644 --- a/test/syscalls/linux/chown.cc +++ b/test/syscalls/linux/chown.cc @@ -91,9 +91,7 @@ using Chown = class ChownParamTest : public ::testing::TestWithParam<Chown> {}; TEST_P(ChownParamTest, ChownFileSucceeds) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_CHOWN))) { - ASSERT_NO_ERRNO(SetCapability(CAP_CHOWN, false)); - } + AutoCapability cap(CAP_CHOWN, false); const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -135,9 +133,7 @@ TEST_P(ChownParamTest, ChownFilePermissionDenied) { // thread won't be able to open some log files after the test ends. ScopedThread([&] { // Drop privileges. 
- if (HaveCapability(CAP_CHOWN).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_CHOWN, false)); - } + AutoCapability cap(CAP_CHOWN, false); // Change EUID and EGID. // diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 1d0d584cd..32860aa21 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -117,7 +117,7 @@ TEST(DevTest, ReadDevNull) { } // Do not allow random save as it could lead to partial reads. -TEST(DevTest, ReadDevZero_NoRandomSave) { +TEST(DevTest, ReadDevZero) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_RDONLY)); diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index 8a72ef10a..af3d27894 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -39,6 +39,15 @@ namespace { constexpr int kFDsPerEpoll = 3; constexpr uint64_t kMagicConstant = 0x0102030405060708; +#ifndef SYS_epoll_pwait2 +#define SYS_epoll_pwait2 441 +#endif + +int epoll_pwait2(int fd, struct epoll_event* events, int maxevents, + const struct timespec* timeout, const sigset_t* sigset) { + return syscall(SYS_epoll_pwait2, fd, events, maxevents, timeout, sigset); +} + TEST(EpollTest, AllWritable) { auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); std::vector<FileDescriptor> eventfds; @@ -115,7 +124,7 @@ TEST(EpollTest, LastNonWritable) { } } -TEST(EpollTest, Timeout_NoRandomSave) { +TEST(EpollTest, Timeout) { auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); std::vector<FileDescriptor> eventfds; for (int i = 0; i < kFDsPerEpoll; i++) { @@ -144,6 +153,50 @@ TEST(EpollTest, Timeout_NoRandomSave) { EXPECT_GT(ms_elapsed(begin, end), kTimeoutMs - 1); } +TEST(EpollTest, EpollPwait2Timeout) { + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + // 200 milliseconds. 
+ constexpr int kTimeoutNs = 200000000; + struct timespec timeout; + timeout.tv_sec = 0; + timeout.tv_nsec = 0; + struct timespec begin; + struct timespec end; + struct epoll_event result[kFDsPerEpoll]; + + std::vector<FileDescriptor> eventfds; + for (int i = 0; i < kFDsPerEpoll; i++) { + eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD())); + ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, + kMagicConstant + i)); + } + + // Pass valid arguments so that the syscall won't be blocked indefinitely + // nor return errno EINVAL. + // + // The syscall returns immediately when timeout is zero, + // even if no events are available. + SKIP_IF(!IsRunningOnGvisor() && + epoll_pwait2(epollfd.get(), result, kFDsPerEpoll, &timeout, nullptr) < + 0 && + errno == ENOSYS); + + { + const DisableSave ds; // Timing-related. + EXPECT_THAT(clock_gettime(CLOCK_MONOTONIC, &begin), SyscallSucceeds()); + + timeout.tv_nsec = kTimeoutNs; + ASSERT_THAT(RetryEINTR(epoll_pwait2)(epollfd.get(), result, kFDsPerEpoll, + &timeout, nullptr), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(clock_gettime(CLOCK_MONOTONIC, &end), SyscallSucceeds()); + } + + // Check the lower bound on the timeout. Checking for an upper bound is + // fragile because Linux can overrun the timeout due to scheduling delays. + EXPECT_GT(ns_elapsed(begin, end), kTimeoutNs - 1); +} + void* writer(void* arg) { int fd = *reinterpret_cast<int*>(arg); uint64_t tmp = 1; @@ -290,7 +343,7 @@ TEST(EpollTest, Oneshot) { SyscallSucceedsWithValue(0)); } -TEST(EpollTest, EdgeTriggered_NoRandomSave) { +TEST(EpollTest, EdgeTriggered) { // Test edge-triggered entry: make it edge-triggered, first wait should // return it, second one should time out, make it writable again, third wait // should return it, fourth wait should timeout. 
diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc index dc794415e..8202d35fa 100644 --- a/test/syscalls/linux/eventfd.cc +++ b/test/syscalls/linux/eventfd.cc @@ -175,7 +175,7 @@ TEST(EventfdTest, SpliceFromPipePartialSucceeds) { } // NotifyNonZero is inherently racy, so random save is disabled. -TEST(EventfdTest, NotifyNonZero_NoRandomSave) { +TEST(EventfdTest, NotifyNonZero) { // Waits will time out at 10 seconds. constexpr int kEpollTimeoutMs = 10000; // Create an eventfd descriptor. diff --git a/test/syscalls/linux/fchdir.cc b/test/syscalls/linux/fchdir.cc index c6675802d..0383f3f85 100644 --- a/test/syscalls/linux/fchdir.cc +++ b/test/syscalls/linux/fchdir.cc @@ -46,8 +46,8 @@ TEST(FchdirTest, InvalidFD) { TEST(FchdirTest, PermissionDenied) { // Drop capabilities that allow us to override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index b286e84fe..fd387aa45 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -205,7 +205,7 @@ TEST_F(FlockTest, TestSharedLockFailExclusiveHolderNonblocking) { void trivial_handler(int signum) {} -TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking_NoRandomSave) { +TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking) { const DisableSave ds; // Timing-related. 
// This test will verify that a shared lock is denied while @@ -262,7 +262,7 @@ TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderNonblocking) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } -TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking_NoRandomSave) { +TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking) { const DisableSave ds; // Timing-related. // This test will verify that an exclusive lock is denied while @@ -499,7 +499,7 @@ TEST_F(FlockTest, TestDupFdFollowedByLock) { // NOTE: These blocking tests are not perfect. Unfortunately it's very hard to // determine if a thread was actually blocked in the kernel so we're forced // to use timing. -TEST_F(FlockTest, BlockingLockNoBlockingForSharedLocks_NoRandomSave) { +TEST_F(FlockTest, BlockingLockNoBlockingForSharedLocks) { // This test will verify that although LOCK_NB isn't specified // two different fds can obtain shared locks without blocking. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH), SyscallSucceeds()); @@ -539,7 +539,7 @@ TEST_F(FlockTest, BlockingLockNoBlockingForSharedLocks_NoRandomSave) { EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); } -TEST_F(FlockTest, BlockingLockFirstSharedSecondExclusive_NoRandomSave) { +TEST_F(FlockTest, BlockingLockFirstSharedSecondExclusive) { // This test will verify that if someone holds a shared lock any attempt to // obtain an exclusive lock will result in blocking. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_SH), SyscallSucceeds()); @@ -576,7 +576,7 @@ TEST_F(FlockTest, BlockingLockFirstSharedSecondExclusive_NoRandomSave) { EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); } -TEST_F(FlockTest, BlockingLockFirstExclusiveSecondShared_NoRandomSave) { +TEST_F(FlockTest, BlockingLockFirstExclusiveSecondShared) { // This test will verify that if someone holds an exclusive lock any attempt // to obtain a shared lock will result in blocking. 
ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX), SyscallSucceeds()); @@ -613,7 +613,7 @@ TEST_F(FlockTest, BlockingLockFirstExclusiveSecondShared_NoRandomSave) { EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds()); } -TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive_NoRandomSave) { +TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive) { // This test will verify that if someone holds an exclusive lock any attempt // to obtain another exclusive lock will result in blocking. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX), SyscallSucceeds()); diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc index c47567b4e..79b0596c4 100644 --- a/test/syscalls/linux/fpsig_fork.cc +++ b/test/syscalls/linux/fpsig_fork.cc @@ -44,6 +44,8 @@ namespace { #define SET_FP0(var) SET_FPREG(var, d0) #endif +#define DEFAULT_MXCSR 0x1f80 + int parent, child; void sigusr1(int s, siginfo_t* siginfo, void* _uc) { @@ -57,6 +59,12 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) { uint64_t got; GET_FP0(got); TEST_CHECK_MSG(val == got, "Basic FP check failed in sigusr1()"); + +#ifdef __x86_64 + uint32_t mxcsr; + __asm__("STMXCSR %0" : "=m"(mxcsr)); + TEST_CHECK_MSG(mxcsr == DEFAULT_MXCSR, "Unexpected mxcsr"); +#endif } TEST(FPSigTest, Fork) { @@ -125,6 +133,55 @@ TEST(FPSigTest, Fork) { } } +#ifdef __x86_64__ +TEST(FPSigTest, ForkWithZeroMxcsr) { + parent = getpid(); + pid_t parent_tid = gettid(); + + struct sigaction sa = {}; + sigemptyset(&sa.sa_mask); + sa.sa_flags = SA_SIGINFO; + sa.sa_sigaction = sigusr1; + ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); + + // The control bits of the MXCSR register are callee-saved (preserved across + // calls), while the status bits are caller-saved (not preserved). 
+ uint32_t expected = 0, origin; + __asm__("STMXCSR %0" : "=m"(origin)); + __asm__("LDMXCSR %0" : : "m"(expected)); + + asm volatile( + "movl %[killnr], %%eax;" + "movl %[parent], %%edi;" + "movl %[tid], %%esi;" + "movl %[sig], %%edx;" + "syscall;" + : + : [killnr] "i"(__NR_tgkill), [parent] "rm"(parent), + [tid] "rm"(parent_tid), [sig] "i"(SIGUSR1) + : "rax", "rdi", "rsi", "rdx", + // Clobbered by syscall. + "rcx", "r11"); + + uint32_t got; + __asm__("STMXCSR %0" : "=m"(got)); + __asm__("LDMXCSR %0" : : "m"(origin)); + + if (getpid() == parent) { // Parent. + int status; + ASSERT_THAT(waitpid(child, &status, 0), SyscallSucceedsWithValue(child)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + } + + // TEST_CHECK_MSG since this may run in the child. + TEST_CHECK_MSG(expected == got, "Bad mxcsr value"); + + if (getpid() != parent) { // Child. + _exit(0); + } +} +#endif + } // namespace } // namespace testing diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc index 90b1f0508..859f92b75 100644 --- a/test/syscalls/linux/futex.cc +++ b/test/syscalls/linux/futex.cc @@ -220,7 +220,7 @@ TEST_P(PrivateAndSharedFutexTest, Wait_ZeroBitset) { SyscallFailsWithErrno(EINVAL)); } -TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, Wake1) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -240,7 +240,7 @@ TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) { EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1)); } -TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, Wake0) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -261,7 +261,7 @@ TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) { EXPECT_THAT(futex_wake(IsPrivate(), &a, 0), SyscallSucceedsWithValue(1)); } -TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, 
WakeAll) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -282,7 +282,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) { SyscallSucceedsWithValue(kThreads)); } -TEST_P(PrivateAndSharedFutexTest, WakeSome_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, WakeSome) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -331,7 +331,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeSome_NoRandomSave) { EXPECT_EQ(timedout, kThreads - kWokenThreads); } -TEST_P(PrivateAndSharedFutexTest, WaitBitset_Wake_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, WaitBitset_Wake) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -346,7 +346,7 @@ TEST_P(PrivateAndSharedFutexTest, WaitBitset_Wake_NoRandomSave) { EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1)); } -TEST_P(PrivateAndSharedFutexTest, Wait_WakeBitset_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, Wait_WakeBitset) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -361,7 +361,7 @@ TEST_P(PrivateAndSharedFutexTest, Wait_WakeBitset_NoRandomSave) { SyscallSucceedsWithValue(1)); } -TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetMatch_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetMatch) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -379,7 +379,7 @@ TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetMatch_NoRandomSave) { SyscallSucceedsWithValue(1)); } -TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetNoMatch_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetNoMatch) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -401,7 +401,7 @@ TEST_P(PrivateAndSharedFutexTest, WaitBitset_WakeBitsetNoMatch_NoRandomSave) { SyscallSucceedsWithValue(0)); } -TEST_P(PrivateAndSharedFutexTest, 
WakeOpCondSuccess_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, WakeOpCondSuccess) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); std::atomic<int> b = ATOMIC_VAR_INIT(kInitialValue); @@ -428,7 +428,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeOpCondSuccess_NoRandomSave) { EXPECT_EQ(b, kInitialValue + 2); } -TEST_P(PrivateAndSharedFutexTest, WakeOpCondFailure_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, WakeOpCondFailure) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); std::atomic<int> b = ATOMIC_VAR_INIT(kInitialValue); @@ -457,7 +457,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeOpCondFailure_NoRandomSave) { EXPECT_EQ(b, kInitialValue + 2); } -TEST_P(PrivateAndSharedFutexTest, NoWakeInterprocessPrivateAnon_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, NoWakeInterprocessPrivateAnon) { auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr()); @@ -484,7 +484,7 @@ TEST_P(PrivateAndSharedFutexTest, NoWakeInterprocessPrivateAnon_NoRandomSave) { << " status " << status; } -TEST_P(PrivateAndSharedFutexTest, WakeAfterCOWBreak_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, WakeAfterCOWBreak) { // Use a futex on a non-stack mapping so we can be sure that the child process // below isn't the one that breaks copy-on-write. 
auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( @@ -520,7 +520,7 @@ TEST_P(PrivateAndSharedFutexTest, WakeAfterCOWBreak_NoRandomSave) { EXPECT_THAT(futex_wake(IsPrivate(), ptr, 1), SyscallSucceedsWithValue(1)); } -TEST_P(PrivateAndSharedFutexTest, WakeWrongKind_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, WakeWrongKind) { constexpr int kInitialValue = 1; std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue); @@ -584,7 +584,7 @@ TEST(PrivateFutexTest, WakeOp0Xor) { EXPECT_EQ(a, 0b0110); } -TEST(SharedFutexTest, WakeInterprocessSharedAnon_NoRandomSave) { +TEST(SharedFutexTest, WakeInterprocessSharedAnon) { auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED)); auto const ptr = static_cast<std::atomic<int>*>(mapping.ptr()); @@ -615,7 +615,7 @@ TEST(SharedFutexTest, WakeInterprocessSharedAnon_NoRandomSave) { << " status " << status; } -TEST(SharedFutexTest, WakeInterprocessFile_NoRandomSave) { +TEST(SharedFutexTest, WakeInterprocessFile) { auto const file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); ASSERT_THAT(truncate(file.path().c_str(), kPageSize), SyscallSucceeds()); auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); @@ -661,7 +661,7 @@ TEST_P(PrivateAndSharedFutexTest, PIBasic) { EXPECT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallFailsWithErrno(EPERM)); } -TEST_P(PrivateAndSharedFutexTest, PIConcurrency_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, PIConcurrency) { DisableSave ds; // Too many syscalls. std::atomic<int> a = ATOMIC_VAR_INIT(0); @@ -717,7 +717,7 @@ TEST_P(PrivateAndSharedFutexTest, PITryLock) { ASSERT_THAT(futex_unlock_pi(IsPrivate(), &a), SyscallSucceeds()); } -TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) { +TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency) { DisableSave ds; // Too many syscalls. 
std::atomic<int> a = ATOMIC_VAR_INIT(0); diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc index a88c89e20..f6b78989b 100644 --- a/test/syscalls/linux/inotify.cc +++ b/test/syscalls/linux/inotify.cc @@ -1156,7 +1156,7 @@ TEST(Inotify, ZeroLengthReadWriteDoesNotGenerateEvent) { EXPECT_TRUE(events.empty()); } -TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) { +TEST(Inotify, ChmodGeneratesAttribEvent) { const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); @@ -1999,7 +1999,7 @@ TEST(Inotify, Exec) { // // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. -TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { +TEST(Inotify, IncludeUnlinkedFile) { const DisableSave ds; const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -2052,7 +2052,7 @@ TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { // // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. -TEST(Inotify, ExcludeUnlink_NoRandomSave) { +TEST(Inotify, ExcludeUnlink) { const DisableSave ds; // TODO(gvisor.dev/issue/1624): This test fails on VFS1. SKIP_IF(IsRunningWithVFS1()); @@ -2093,7 +2093,7 @@ TEST(Inotify, ExcludeUnlink_NoRandomSave) { // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. -TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) { +TEST(Inotify, ExcludeUnlinkDirectory) { // TODO(gvisor.dev/issue/1624): This test fails on VFS1. Remove once VFS1 is // deleted. 
SKIP_IF(IsRunningWithVFS1()); @@ -2138,7 +2138,7 @@ TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) { // // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. -TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) { +TEST(Inotify, ExcludeUnlinkMultipleChildren) { // Inotify does not work properly with hard links in gofer and overlay fs. SKIP_IF(IsRunningOnGvisor() && !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); @@ -2184,7 +2184,7 @@ TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) { // // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. -TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { +TEST(Inotify, ExcludeUnlinkInodeEvents) { // TODO(gvisor.dev/issue/1624): Fails on VFS1. SKIP_IF(IsRunningWithVFS1()); @@ -2284,7 +2284,7 @@ TEST(Inotify, OneShot) { // This test helps verify that the lock order of filesystem and inotify locks // is respected when inotify instances and watch targets are concurrently being // destroyed. -TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock_NoRandomSave) { +TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock) { const DisableSave ds; // Too many syscalls. // A file descriptor protected by a mutex. This ensures that while a @@ -2350,7 +2350,7 @@ TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock_NoRandomSave) { // This test helps verify that the lock order of filesystem and inotify locks // is respected when adding/removing watches occurs concurrently with the // removal of their targets. -TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock_NoRandomSave) { +TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock) { const DisableSave ds; // Too many syscalls. // Set up inotify instances. 
@@ -2405,7 +2405,7 @@ TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock_NoRandomSave) { // This test helps verify that the lock order of filesystem and inotify locks // is respected when many inotify events and filesystem operations occur // simultaneously. -TEST(InotifyTest, NotifyNoDeadlock_NoRandomSave) { +TEST(InotifyTest, NotifyNoDeadlock) { const DisableSave ds; // Too many syscalls. const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index e397d5f57..ac113e6da 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -215,7 +215,7 @@ int TestSIGALRMToMainThread() { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. -TEST(ItimerTest, DeliversSIGALRMToMainThread_NoRandomSave) { +TEST(ItimerTest, DeliversSIGALRMToMainThread) { pid_t child; int execve_errno; auto kill = ASSERT_NO_ERRNO_AND_VALUE( @@ -266,7 +266,7 @@ int TestSIGPROFFairness(absl::Duration sleep) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. -TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { +TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive) { // On the KVM and ptrace platforms, switches between sentry and application // context are sometimes extremely slow, causing the itimer to send SIGPROF to // a thread that either already has one pending or has had SIGPROF delivered, @@ -301,7 +301,7 @@ TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) { // Random save/restore is disabled as it introduces additional latency and // unpredictable distribution patterns. -TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) { +TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle) { // See comment in DeliversSIGPROFToThreadsRoughlyFairlyActive. 
const auto gvisor_platform = GvisorPlatform(); SKIP_IF(gvisor_platform == Platform::kKVM || diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc index 94aea4077..867a4513b 100644 --- a/test/syscalls/linux/memory_accounting.cc +++ b/test/syscalls/linux/memory_accounting.cc @@ -83,7 +83,7 @@ TEST(MemoryAccounting, AnonAccountingPreservedOnSaveRestore) { uint64_t anon_after_alloc = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo()); EXPECT_THAT(anon_after_alloc, - EquivalentWithin(anon_initial + map_bytes, 0.03)); + EquivalentWithin(anon_initial + map_bytes, 0.04)); // We have many implicit S/R cycles from scraping /proc/meminfo throughout the // test, but throw an explicit S/R in here as well. @@ -91,7 +91,7 @@ TEST(MemoryAccounting, AnonAccountingPreservedOnSaveRestore) { // Usage should remain the same across S/R. uint64_t anon_after_sr = ASSERT_NO_ERRNO_AND_VALUE(AnonUsageFromMeminfo()); - EXPECT_THAT(anon_after_sr, EquivalentWithin(anon_after_alloc, 0.03)); + EXPECT_THAT(anon_after_sr, EquivalentWithin(anon_after_alloc, 0.04)); } } // namespace diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index 11fbfa5c5..36504fe6d 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -72,8 +72,8 @@ TEST_F(MkdirTest, HonorsUmask2) { TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds()); auto dir = JoinPath(dirname_.c_str(), "foo"); @@ -84,8 +84,8 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { TEST_F(MkdirTest, DirAlreadyExists) { // Drop capabilities that allow us to override file and directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); auto dir = JoinPath(dirname_.c_str(), "foo"); diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc index 78ac96bed..dfa5b7133 100644 --- a/test/syscalls/linux/mlock.cc +++ b/test/syscalls/linux/mlock.cc @@ -114,9 +114,7 @@ TEST(MlockTest, Fork) { } TEST(MlockTest, RlimitMemlockZero) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( @@ -127,9 +125,7 @@ TEST(MlockTest, RlimitMemlockZero) { } TEST(MlockTest, RlimitMemlockInsufficient) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, kPageSize)); auto const mapping = ASSERT_NO_ERRNO_AND_VALUE( @@ -255,9 +251,7 @@ TEST(MapLockedTest, Basic) { } TEST(MapLockedTest, RlimitMemlockZero) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); EXPECT_THAT( @@ -266,9 +260,7 @@ TEST(MapLockedTest, RlimitMemlockZero) { } TEST(MapLockedTest, RlimitMemlockInsufficient) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = 
ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, kPageSize)); EXPECT_THAT( @@ -298,9 +290,7 @@ TEST(MremapLockedTest, RlimitMemlockZero) { MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); EXPECT_TRUE(IsPageMlocked(mapping.addr())); - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE(ScopedSetSoftRlimit(RLIMIT_MEMLOCK, 0)); void* addr = mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(), @@ -315,9 +305,7 @@ TEST(MremapLockedTest, RlimitMemlockInsufficient) { MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_LOCKED)); EXPECT_TRUE(IsPageMlocked(mapping.addr())); - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_IPC_LOCK))) { - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_LOCK, false)); - } + AutoCapability cap(CAP_IPC_LOCK, false); Cleanup reset_rlimit = ASSERT_NO_ERRNO_AND_VALUE( ScopedSetSoftRlimit(RLIMIT_MEMLOCK, mapping.len())); void* addr = mremap(mapping.ptr(), mapping.len(), 2 * mapping.len(), diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index 15b645fb7..3c7311782 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -26,6 +26,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "test/util/capability_util.h" @@ -44,6 +45,10 @@ namespace testing { namespace { +using ::testing::AnyOf; +using ::testing::Contains; +using ::testing::Pair; + TEST(MountTest, MountBadFilesystem) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); @@ -63,9 +68,7 @@ TEST(MountTest, MountInvalidTarget) { TEST(MountTest, MountPermDenied) { // Clear CAP_SYS_ADMIN. 
- if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); - } + AutoCapability cap(CAP_SYS_ADMIN, false); // Linux expects a valid target before checking capability. auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -345,6 +348,37 @@ TEST(MountTest, RenameRemoveMountPoint) { ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); } +TEST(MountTest, MountInfo) { + SKIP_IF(IsRunningWithVFS1()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto const mount = ASSERT_NO_ERRNO_AND_VALUE( + Mount("", dir.path(), "tmpfs", MS_NOEXEC, "mode=0123", 0)); + const std::vector<ProcMountsEntry> mounts = + ASSERT_NO_ERRNO_AND_VALUE(ProcSelfMountsEntries()); + for (const auto& e : mounts) { + if (e.mount_point == dir.path()) { + EXPECT_EQ(e.fstype, "tmpfs"); + auto mopts = ParseMountOptions(e.mount_opts); + EXPECT_THAT(mopts, AnyOf(Contains(Pair("mode", "0123")), + Contains(Pair("mode", "123")))); + } + } + + const std::vector<ProcMountInfoEntry> mountinfo = + ASSERT_NO_ERRNO_AND_VALUE(ProcSelfMountInfoEntries()); + + for (auto const& e : mountinfo) { + if (e.mount_point == dir.path()) { + EXPECT_EQ(e.fstype, "tmpfs"); + auto mopts = ParseMountOptions(e.super_opts); + EXPECT_THAT(mopts, AnyOf(Contains(Pair("mode", "0123")), + Contains(Pair("mode", "123")))); + } + } +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index e65ffee8f..ab9d19fef 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -431,9 +431,9 @@ TEST_F(OpenTest, CanTruncateReadOnly) { // If we don't have read permission on the file, opening with // O_TRUNC should fail. 
-TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission_NoRandomSave) { +TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); const DisableSave ds; // Permissions are dropped. ASSERT_THAT(chmod(test_file_name_.c_str(), S_IRUSR | S_IRGRP), @@ -452,7 +452,7 @@ TEST_F(OpenTest, CanTruncateReadOnlyNoWritePermission_NoRandomSave) { // If we don't have read permission but have write permission, opening O_WRONLY // and O_TRUNC should succeed. -TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission_NoRandomSave) { +TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission) { const DisableSave ds; // Permissions are dropped. EXPECT_THAT(fchmod(test_file_fd_.get(), S_IWUSR | S_IWGRP), @@ -473,8 +473,8 @@ TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission_NoRandomSave) { } TEST_F(OpenTest, CanTruncateWithStrangePermissions) { - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const DisableSave ds; // Permissions are dropped. std::string path = NewTempAbsPath(); // Create a file without user permissions. @@ -510,8 +510,8 @@ TEST_F(OpenTest, OpenWithStrangeFlags) { TEST_F(OpenTest, OpenWithOpath) { SKIP_IF(IsRunningWithVFS1()); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const DisableSave ds; // Permissions are dropped. 
std::string path = NewTempAbsPath(); diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 46f41de50..177bda54d 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -52,7 +52,7 @@ TEST(CreateTest, CreateAtFile) { EXPECT_THAT(close(fd), SyscallSucceeds()); } -TEST(CreateTest, HonorsUmask_NoRandomSave) { +TEST(CreateTest, HonorsUmask) { const DisableSave ds; // file cannot be re-opened as writable. auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); TempUmask mask(0222); @@ -93,7 +93,8 @@ TEST(CreateTest, CreatFileWithOTruncAndReadOnly) { TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // always override directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); + auto parent = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555)); auto file = JoinPath(parent.path(), "foo"); @@ -119,12 +120,12 @@ TEST(CreateTest, OpenCreateROThenRW) { EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1)); } -TEST(CreateTest, ChmodReadToWriteBetweenOpens_NoRandomSave) { +TEST(CreateTest, ChmodReadToWriteBetweenOpens) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be // cleared for the same reason. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400)); @@ -149,10 +150,10 @@ TEST(CreateTest, ChmodReadToWriteBetweenOpens_NoRandomSave) { EXPECT_EQ(c, 'x'); } -TEST(CreateTest, ChmodWriteToReadBetweenOpens_NoRandomSave) { +TEST(CreateTest, ChmodWriteToReadBetweenOpens) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // override file read/write permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0200)); @@ -177,7 +178,7 @@ TEST(CreateTest, ChmodWriteToReadBetweenOpens_NoRandomSave) { EXPECT_EQ(c, 'x'); } -TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) { +TEST(CreateTest, CreateWithReadFlagNotAllowedByMode) { // The only time we can open a file with flags forbidden by its permissions // is when we are creating the file. We cannot re-open with the same flags, // so we cannot restore an fd obtained from such an operation. @@ -186,8 +187,8 @@ TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be // cleared for the same reason. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); // Create and open a file with read flag but without read permissions. 
const std::string path = NewTempAbsPath(); @@ -204,7 +205,7 @@ TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) { EXPECT_EQ(c, 'x'); } -TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode_NoRandomSave) { +TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode) { // The only time we can open a file with flags forbidden by its permissions // is when we are creating the file. We cannot re-open with the same flags, // so we cannot restore an fd obtained from such an operation. @@ -212,7 +213,7 @@ TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode_NoRandomSave) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // override file read/write permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); // Create and open a file with write flag but without write permissions. const std::string path = NewTempAbsPath(); diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index d25be0e30..72080a272 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -440,11 +440,7 @@ TEST_P(RawPacketTest, SetSocketRecvBuf) { ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), SyscallSucceeds()); - // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc index 13afa0eaf..223ddc0c8 100644 --- a/test/syscalls/linux/partial_bad_buffer.cc +++ b/test/syscalls/linux/partial_bad_buffer.cc @@ -320,7 +320,7 @@ PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) { // EFAULT. 
It also verifies that passing a buffer which is made up of 2 // pages one valid and one guard page succeeds as long as the write is // for exactly the size of 1 page. -TEST_F(PartialBadBufferTest, SendMsgTCP_NoRandomSave) { +TEST_F(PartialBadBufferTest, SendMsgTCP) { // FIXME(b/171436815): Netstack save/restore is broken. const DisableSave ds; diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc index 999c8ab6b..8b78e4b16 100644 --- a/test/syscalls/linux/ping_socket.cc +++ b/test/syscalls/linux/ping_socket.cc @@ -35,7 +35,7 @@ namespace { // // We disable both random/cooperative S/R for this test as it makes way too many // syscalls. -TEST(PingSocket, ICMPPortExhaustion_NoRandomSave) { +TEST(PingSocket, ICMPPortExhaustion) { DisableSave ds; { diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 01ccbdcd2..96c454485 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -399,7 +399,7 @@ TEST_P(PipeTest, BlockPartialWriteClosed) { t.Join(); } -TEST_P(PipeTest, ReadFromClosedFd_NoRandomSave) { +TEST_P(PipeTest, ReadFromClosedFd) { SKIP_IF(!CreateBlocking()); absl::Notification notify; diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc index 6f9a9498c..5ce7e8c8d 100644 --- a/test/syscalls/linux/poll.cc +++ b/test/syscalls/linux/poll.cc @@ -57,7 +57,7 @@ TEST_F(PollTest, ZeroTimeout) { // If random S/R interrupts the poll, SIGALRM may be delivered before poll // restarts, causing the poll to hang forever. -TEST_F(PollTest, NegativeTimeout_NoRandomSave) { +TEST_F(PollTest, NegativeTimeout) { // Negative timeout mean wait forever so set a timer. 
SetTimer(absl::Milliseconds(100)); EXPECT_THAT(poll(nullptr, 0, -1), SyscallFailsWithErrno(EINTR)); diff --git a/test/syscalls/linux/ppoll.cc b/test/syscalls/linux/ppoll.cc index 8245a11e8..7f7d69731 100644 --- a/test/syscalls/linux/ppoll.cc +++ b/test/syscalls/linux/ppoll.cc @@ -76,7 +76,7 @@ TEST_F(PpollTest, ZeroTimeout) { // If random S/R interrupts the ppoll, SIGALRM may be delivered before ppoll // restarts, causing the ppoll to hang forever. -TEST_F(PpollTest, NoTimeout_NoRandomSave) { +TEST_F(PpollTest, NoTimeout) { // When there's no timeout, ppoll may never return so set a timer. SetTimer(absl::Milliseconds(100)); // See that we get interrupted by the timer. diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index f675dc430..19a57d353 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -184,10 +184,8 @@ TEST(PrctlTest, PDeathSig) { // This test is to validate that calling prctl with PR_SET_MM without the // CAP_SYS_RESOURCE returns EPERM. TEST(PrctlTest, InvalidPrSetMM) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))) { - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_RESOURCE, - false)); // Drop capability to test below. - } + // Drop capability to test below. 
+ AutoCapability cap(CAP_SYS_RESOURCE, false); ASSERT_THAT(prctl(PR_SET_MM, 0, 0, 0, 0), SyscallFailsWithErrno(EPERM)); } diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index c74990ba1..0a09259a3 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -144,7 +144,7 @@ TEST_F(Pread64Test, Overflow) { SyscallFailsWithErrno(EINVAL)); } -TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) { +TEST(Pread64TestNoTempFile, CantReadSocketPair) { int sock_fds[2]; EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds()); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 493042dfc..24928d876 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -65,6 +65,7 @@ #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/memory_util.h" +#include "test/util/mount_util.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" #include "test/util/proc_util.h" @@ -1200,6 +1201,15 @@ TEST(ProcSelfCwd, Absolute) { EXPECT_EQ(exe[0], '/'); } +// Sanity check that /proc/cmdline is present. +TEST(ProcCmdline, IsPresent) { + SKIP_IF(IsRunningWithVFS1()); + + std::string proc_cmdline = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/cmdline")); + ASSERT_FALSE(proc_cmdline.empty()); +} + // Sanity check for /proc/cpuinfo fields that must be present. TEST(ProcCpuinfo, RequiredFieldsArePresent) { std::string proc_cpuinfo = @@ -1629,7 +1639,7 @@ TEST(ProcPidStatusTest, StateRunning) { IsPosixErrorOkAndHolds(Contains(Pair("State", "R (running)")))); } -TEST(ProcPidStatusTest, StateSleeping_NoRandomSave) { +TEST(ProcPidStatusTest, StateSleeping) { // Starts a child process that blocks and checks that State is sleeping. 
auto res = WithSubprocess( [&](int pid) -> PosixError { @@ -1848,8 +1858,8 @@ TEST(ProcPidSymlink, SubprocessRunning) { } TEST(ProcPidSymlink, SubprocessZombied) { - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); char buf[1]; @@ -2251,7 +2261,7 @@ TEST(ProcTask, VerifyTaskDir) { TEST(ProcTask, TaskDirCannotBeDeleted) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); EXPECT_THAT(rmdir("/proc/self/task"), SyscallFails()); EXPECT_THAT(rmdir(absl::StrCat("/proc/self/task/", getpid()).c_str()), @@ -2468,6 +2478,19 @@ TEST(ProcSelfMountinfo, RequiredFieldsArePresent) { R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)"))); } +TEST(ProcSelfMountinfo, ContainsProcfsEntry) { + const std::vector<ProcMountInfoEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcSelfMountInfoEntries()); + bool found = false; + for (const auto& e : entries) { + if (e.fstype == "proc") { + found = true; + break; + } + } + EXPECT_TRUE(found); +} + // Check that /proc/self/mounts looks something like a real mounts file. 
TEST(ProcSelfMounts, RequiredFieldsArePresent) { auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts")); @@ -2479,6 +2502,19 @@ TEST(ProcSelfMounts, RequiredFieldsArePresent) { ContainsRegex(R"(\S+ /proc \S+ rw\S* [0-9]+ [0-9]+\s)"))); } +TEST(ProcSelfMounts, ContainsProcfsEntry) { + const std::vector<ProcMountsEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcSelfMountsEntries()); + bool found = false; + for (const auto& e : entries) { + if (e.fstype == "proc") { + found = true; + break; + } + } + EXPECT_TRUE(found); +} + void CheckDuplicatesRecursively(std::string path) { std::vector<std::string> child_dirs; @@ -2671,6 +2707,14 @@ TEST(Proc, Statfs) { EXPECT_EQ(st.f_namelen, NAME_MAX); } +// Tests that /proc/[pid]/fd/[num] can resolve to a path inside /proc. +TEST(Proc, ResolveSymlinkToProc) { + const auto proc = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/cmdline", 0)); + const auto path = JoinPath("/proc/self/fd/", absl::StrCat(proc.get())); + const auto target = ASSERT_NO_ERRNO_AND_VALUE(ReadLink(path)); + EXPECT_EQ(target, JoinPath("/proc/", absl::StrCat(getpid()), "/cmdline")); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 20f1dc305..04fecc02e 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -189,7 +189,7 @@ PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp, EINVAL, absl::StrCat("failed to find ", type, "/", item, " in:", snmp)); } -TEST(ProcNetSnmp, TcpReset_NoRandomSave) { +TEST(ProcNetSnmp, TcpReset) { // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. DisableSave ds; @@ -231,7 +231,7 @@ TEST(ProcNetSnmp, TcpReset_NoRandomSave) { EXPECT_EQ(oldAttemptFails, newAttemptFails - 1); } -TEST(ProcNetSnmp, TcpEstab_NoRandomSave) { +TEST(ProcNetSnmp, TcpEstab) { // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. 
DisableSave ds; @@ -263,9 +263,8 @@ TEST(ProcNetSnmp, TcpEstab_NoRandomSave) { // Get the port bound by the listening socket. socklen_t addrlen = sizeof(sin); - ASSERT_THAT( - getsockname(s_listen.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen), - SyscallSucceeds()); + ASSERT_THAT(getsockname(s_listen.get(), AsSockAddr(&sin), &addrlen), + SyscallSucceeds()); FileDescriptor s_connect = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0)); @@ -326,7 +325,7 @@ TEST(ProcNetSnmp, TcpEstab_NoRandomSave) { EXPECT_EQ(oldEstabResets, newEstabResets - 2); } -TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { +TEST(ProcNetSnmp, UdpNoPorts) { // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. DisableSave ds; @@ -360,7 +359,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) { EXPECT_EQ(oldNoPorts, newNoPorts - 1); } -TEST(ProcNetSnmp, UdpIn_NoRandomSave) { +TEST(ProcNetSnmp, UdpIn) { // TODO(gvisor.dev/issue/866): epsocket metrics are not savable. const DisableSave ds; @@ -384,9 +383,8 @@ TEST(ProcNetSnmp, UdpIn_NoRandomSave) { SyscallSucceeds()); // Get the port bound by the server socket. 
socklen_t addrlen = sizeof(sin); - ASSERT_THAT( - getsockname(server.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen), - SyscallSucceeds()); + ASSERT_THAT(getsockname(server.get(), AsSockAddr(&sin), &addrlen), + SyscallSucceeds()); FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); @@ -421,14 +419,14 @@ TEST(ProcNetSnmp, CheckNetStat) { int name_count = 0; int value_count = 0; std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n'); - for (long unsigned int i = 0; i + 1 < lines.size(); i += 2) { + for (size_t i = 0; i + 1 < lines.size(); i += 2) { std::vector<absl::string_view> names = absl::StrSplit(lines[i], absl::ByAnyChar("\t ")); std::vector<absl::string_view> values = absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t ")); EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i] << "' and '" << lines[i + 1] << "'"; - for (long unsigned int j = 0; j < names.size() && j < values.size(); ++j) { + for (size_t j = 0; j < names.size() && j < values.size(); ++j) { if (names[j] == "TCPOrigDataSent" || names[j] == "TCPSynRetrans" || names[j] == "TCPDSACKRecv" || names[j] == "TCPDSACKOfoRecv") { ++name_count; @@ -458,14 +456,14 @@ TEST(ProcNetSnmp, CheckSnmp) { int name_count = 0; int value_count = 0; std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n'); - for (long unsigned int i = 0; i + 1 < lines.size(); i += 2) { + for (size_t i = 0; i + 1 < lines.size(); i += 2) { std::vector<absl::string_view> names = absl::StrSplit(lines[i], absl::ByAnyChar("\t ")); std::vector<absl::string_view> values = absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t ")); EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i] << "' and '" << lines[i + 1] << "'"; - for (long unsigned int j = 0; j < names.size() && j < values.size(); ++j) { + for (size_t j = 0; j < names.size() && j < values.size(); ++j) { if (names[j] == "RetransSegs") { ++name_count; int64_t val; diff --git 
a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index d61d94309..f7ff65aad 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -182,7 +182,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { // Returns true on match, and sets 'match' to point to the matching entry. bool FindBy(std::vector<UnixEntry> entries, UnixEntry* match, std::function<bool(const UnixEntry&)> predicate) { - for (long unsigned int i = 0; i < entries.size(); ++i) { + for (size_t i = 0; i < entries.size(); ++i) { if (predicate(entries[i])) { *match = entries[i]; return true; @@ -201,15 +201,8 @@ TEST(ProcNetUnix, Exists) { const std::string content = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/unix")); const std::string header_line = StrCat(kProcNetUnixHeader, "\n"); - if (IsRunningOnGvisor()) { - // Should be just the header since we don't have any unix domain sockets - // yet. - EXPECT_EQ(content, header_line); - } else { - // However, on a general linux machine, we could have abitrary sockets on - // the system, so just check the header. - EXPECT_THAT(content, ::testing::StartsWith(header_line)); - } + // We could have abitrary sockets on the system, so just check the header. + EXPECT_THAT(content, ::testing::StartsWith(header_line)); } TEST(ProcNetUnix, FilesystemBindAcceptConnect) { @@ -223,9 +216,6 @@ TEST(ProcNetUnix, FilesystemBindAcceptConnect) { std::vector<UnixEntry> entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - if (IsRunningOnGvisor()) { - EXPECT_EQ(entries.size(), 2); - } // The server-side socket's path is listed in the socket entry... UnixEntry s1; @@ -247,9 +237,6 @@ TEST(ProcNetUnix, AbstractBindAcceptConnect) { std::vector<UnixEntry> entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - if (IsRunningOnGvisor()) { - EXPECT_EQ(entries.size(), 2); - } // The server-side socket's path is listed in the socket entry... 
UnixEntry s1; @@ -261,20 +248,12 @@ TEST(ProcNetUnix, AbstractBindAcceptConnect) { } TEST(ProcNetUnix, SocketPair) { - // Under gvisor, ensure a socketpair() syscall creates exactly 2 new - // entries. We have no way to verify this under Linux, as we have no control - // over socket creation on a general Linux machine. - SKIP_IF(!IsRunningOnGvisor()); - - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - ASSERT_EQ(entries.size(), 0); - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_STREAM).Create()); - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - EXPECT_EQ(entries.size(), 2); + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + EXPECT_GE(entries.size(), 2); } TEST(ProcNetUnix, StreamSocketStateUnconnectedOnBind) { @@ -368,25 +347,12 @@ TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE( AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - - // On gVisor, the only two UDS on the system are the ones we just created and - // we rely on this to locate the test socket entries in the remainder of the - // test. On a generic Linux system, we have no easy way to locate the - // corresponding entries, as they don't have an address yet. 
- if (IsRunningOnGvisor()) { - ASSERT_EQ(entries.size(), 2); - for (const auto& e : entries) { - ASSERT_EQ(e.state, SS_DISCONNECTING); - } - } - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), sockets->first_addr_size()), SyscallSucceeds()); - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); const std::string address = ExtractPath(sockets->first_addr()); UnixEntry bind_entry; ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); @@ -397,25 +363,12 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE( AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); - std::vector<UnixEntry> entries = - ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - - // On gVisor, the only two UDS on the system are the ones we just created and - // we rely on this to locate the test socket entries in the remainder of the - // test. On a generic Linux system, we have no easy way to locate the - // corresponding entries, as they don't have an address yet. 
- if (IsRunningOnGvisor()) { - ASSERT_EQ(entries.size(), 2); - for (const auto& e : entries) { - ASSERT_EQ(e.state, SS_DISCONNECTING); - } - } - ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), sockets->first_addr_size()), SyscallSucceeds()); - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); const std::string address = ExtractPath(sockets->first_addr()); UnixEntry bind_entry; ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); @@ -423,22 +376,6 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), sockets->first_addr_size()), SyscallSucceeds()); - - entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); - - // Once again, we have no easy way to identify the connecting socket as it has - // no listed address. We can only identify the entry as the "non-bind socket - // entry" on gVisor, where we're guaranteed to have only the two entries we - // create during this test. 
- if (IsRunningOnGvisor()) { - ASSERT_EQ(entries.size(), 2); - UnixEntry connect_entry; - ASSERT_TRUE( - FindBy(entries, &connect_entry, [bind_entry](const UnixEntry& e) { - return e.inode != bind_entry.inode; - })); - EXPECT_EQ(connect_entry.state, SS_CONNECTING); - } } } // namespace diff --git a/test/syscalls/linux/proc_pid_uid_gid_map.cc b/test/syscalls/linux/proc_pid_uid_gid_map.cc index af052a63c..c030592c8 100644 --- a/test/syscalls/linux/proc_pid_uid_gid_map.cc +++ b/test/syscalls/linux/proc_pid_uid_gid_map.cc @@ -203,8 +203,9 @@ TEST_P(ProcSelfUidGidMapTest, IdentityMapOwnID) { EXPECT_THAT( InNewUserNamespaceWithMapFD([&](int fd) { DenySelfSetgroups(); - TEST_PCHECK(static_cast<long unsigned int>( - write(fd, line.c_str(), line.size())) == line.size()); + size_t n; + TEST_PCHECK((n = write(fd, line.c_str(), line.size())) != -1); + TEST_CHECK(n == line.size()); }), IsPosixErrorOkAndHolds(0)); } @@ -221,8 +222,9 @@ TEST_P(ProcSelfUidGidMapTest, TrailingNewlineAndNULIgnored) { DenySelfSetgroups(); // The write should return the full size of the write, even though // characters after the NUL were ignored. - TEST_PCHECK(static_cast<long unsigned int>( - write(fd, line.c_str(), line.size())) == line.size()); + size_t n; + TEST_PCHECK((n = write(fd, line.c_str(), line.size())) != -1); + TEST_CHECK(n == line.size()); }), IsPosixErrorOkAndHolds(0)); } diff --git a/test/syscalls/linux/pselect.cc b/test/syscalls/linux/pselect.cc index 4e43c4d7f..e490a987d 100644 --- a/test/syscalls/linux/pselect.cc +++ b/test/syscalls/linux/pselect.cc @@ -88,7 +88,7 @@ TEST_F(PselectTest, ZeroTimeout) { // If random S/R interrupts the pselect, SIGALRM may be delivered before pselect // restarts, causing the pselect to hang forever. -TEST_F(PselectTest, NoTimeout_NoRandomSave) { +TEST_F(PselectTest, NoTimeout) { // When there's no timeout, pselect may never return so set a timer. SetTimer(absl::Milliseconds(100)); // See that we get interrupted by the timer. 
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index d1d7c6f84..d519b65e6 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -175,7 +175,7 @@ TEST(PtraceTest, AttachSameThreadGroup) { TEST(PtraceTest, TraceParentNotAllowed) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) < 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); pid_t const child_pid = fork(); if (child_pid == 0) { @@ -193,7 +193,7 @@ TEST(PtraceTest, TraceParentNotAllowed) { TEST(PtraceTest, TraceNonDescendantNotAllowed) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) < 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); pid_t const tracee_pid = fork(); if (tracee_pid == 0) { @@ -259,7 +259,7 @@ TEST(PtraceTest, TraceNonDescendantWithCapabilityAllowed) { TEST(PtraceTest, TraceDescendantsAllowed) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) > 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use socket pair to communicate tids to this process from its grandchild. int sockets[2]; @@ -346,7 +346,7 @@ TEST(PtraceTest, PrctlSetPtracerInvalidPID) { TEST(PtraceTest, PrctlSetPtracerPID) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -410,7 +410,7 @@ TEST(PtraceTest, PrctlSetPtracerPID) { TEST(PtraceTest, PrctlSetPtracerAny) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. 
int sockets[2]; @@ -475,7 +475,7 @@ TEST(PtraceTest, PrctlSetPtracerAny) { TEST(PtraceTest, PrctlClearPtracer) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -543,7 +543,7 @@ TEST(PtraceTest, PrctlClearPtracer) { TEST(PtraceTest, PrctlReplacePtracer) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); pid_t const unused_pid = fork(); if (unused_pid == 0) { @@ -633,7 +633,7 @@ TEST(PtraceTest, PrctlReplacePtracer) { // thread group leader is still around. TEST(PtraceTest, PrctlSetPtracerPersistsPastTraceeThreadExit) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -703,7 +703,7 @@ TEST(PtraceTest, PrctlSetPtracerPersistsPastTraceeThreadExit) { // even if the tracee thread is terminated. TEST(PtraceTest, PrctlSetPtracerPersistsPastLeaderExec) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -770,7 +770,7 @@ TEST(PtraceTest, PrctlSetPtracerPersistsPastLeaderExec) { // exec. TEST(PtraceTest, PrctlSetPtracerDoesNotPersistPastNonLeaderExec) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. 
int sockets[2]; @@ -904,7 +904,7 @@ TEST(PtraceTest, PrctlSetPtracerDoesNotPersistPastTracerThreadExit) { [[noreturn]] void RunPrctlSetPtracerDoesNotPersistPastTracerThreadExit( int tracee_tid, int fd) { - TEST_PCHECK(SetCapability(CAP_SYS_PTRACE, false).ok()); + AutoCapability cap(CAP_SYS_PTRACE, false); ScopedThread t([fd] { pid_t const tracer_tid = gettid(); @@ -1033,7 +1033,7 @@ TEST(PtraceTest, PrctlSetPtracerRespectsTracerThreadID) { // attached. TEST(PtraceTest, PrctlClearPtracerDoesNotAffectCurrentTracer) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Use sockets to synchronize between tracer and tracee. int sockets[2]; @@ -1118,7 +1118,7 @@ TEST(PtraceTest, PrctlClearPtracerDoesNotAffectCurrentTracer) { TEST(PtraceTest, PrctlNotInherited) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) != 1); - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); // Allow any ptracer. This should not affect the child processes. ASSERT_THAT(prctl(PR_SET_PTRACER, PR_SET_PTRACER_ANY), SyscallSucceeds()); @@ -1708,8 +1708,7 @@ INSTANTIATE_TEST_SUITE_P(TraceExec, PtraceExecveTest, ::testing::Bool()); // This test has expectations on when syscall-enter/exit-stops occur that are // violated if saving occurs, since saving interrupts all syscalls, causing // premature syscall-exit. -TEST(PtraceTest, - ExitWhenParentIsNotTracer_Syscall_TraceVfork_TraceVforkDone_NoRandomSave) { +TEST(PtraceTest, ExitWhenParentIsNotTracer_Syscall_TraceVfork_TraceVforkDone) { constexpr int kExitTraceeExitCode = 99; pid_t const child_pid = fork(); @@ -2006,7 +2005,7 @@ TEST(PtraceTest, Sysemu_PokeUser) { } // This test also cares about syscall-exit-stop. 
-TEST(PtraceTest, ERESTART_NoRandomSave) { +TEST(PtraceTest, ERESTART) { constexpr int kSigno = SIGUSR1; pid_t const child_pid = fork(); @@ -2303,7 +2302,7 @@ TEST(PtraceTest, SetYAMAPtraceScope) { EXPECT_STREQ(buf.data(), "0\n"); // Test that a child can attach to its parent when ptrace_scope is 0. - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, false)); + AutoCapability cap(CAP_SYS_PTRACE, false); pid_t const child_pid = fork(); if (child_pid == 0) { TEST_PCHECK(CheckPtraceAttach(getppid()) == 0); diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 32924466f..69616b400 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -514,10 +514,7 @@ TEST_P(RawSocketTest, SetSocketRecvBuf) { SyscallSucceeds()); // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } @@ -713,12 +710,7 @@ TEST_P(RawSocketTest, RecvBufLimits) { } // Now set the limit to min * 2. - int new_rcv_buf_sz = min * 4; - if (!IsRunningOnGvisor()) { - // Linux doubles the value specified so just set to min. - new_rcv_buf_sz = min * 2; - } - + int new_rcv_buf_sz = min * 2; ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, sizeof(new_rcv_buf_sz)), SyscallSucceeds()); diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc index 2f25aceb2..8b3d02d97 100644 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -177,10 +177,8 @@ TEST_F(RawHDRINCL, ConnectToLoopback) { SyscallSucceeds()); } -TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { - // FIXME(gvisor.dev/issue/3159): Test currently flaky. - SKIP_IF(true); - +// FIXME(gvisor.dev/issue/3159): Test currently flaky. 
+TEST_F(RawHDRINCL, DISABLED_SendWithoutConnectSucceeds) { struct iphdr hdr = LoopbackHeader(); ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), SyscallSucceedsWithValue(sizeof(hdr))); diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc index 087262535..7056342d7 100644 --- a/test/syscalls/linux/read.cc +++ b/test/syscalls/linux/read.cc @@ -97,7 +97,7 @@ TEST_F(ReadTest, DevNullReturnsEof) { const int kReadSize = 128 * 1024; // Do not allow random save as it could lead to partial reads. -TEST_F(ReadTest, CanReadFullyFromDevZero_NoRandomSave) { +TEST_F(ReadTest, CanReadFullyFromDevZero) { int fd; ASSERT_THAT(fd = open("/dev/zero", O_RDONLY), SyscallSucceeds()); diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index 86808d255..a50d98d21 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -267,7 +267,7 @@ TEST_F(ReadvTest, ReadvWithOpath) { // This test depends on the maximum extent of a single readv() syscall, so // we can't tolerate interruption from saving. -TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) { +TEST(ReadvTestNoFixture, TruncatedAtMax) { // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly // important in environments where automated profiling tools may start // ITIMER_PROF automatically. diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index b1a813de0..76a8da65f 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -259,8 +259,8 @@ TEST(RenameTest, DirectoryDoesNotOverwriteNonemptyDirectory) { TEST(RenameTest, FailsWhenOldParentNotWritable) { // Drop capabilities that allow us to override file and directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); @@ -275,8 +275,8 @@ TEST(RenameTest, FailsWhenOldParentNotWritable) { TEST(RenameTest, FailsWhenNewParentNotWritable) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); @@ -293,8 +293,8 @@ TEST(RenameTest, FailsWhenNewParentNotWritable) { // to overwrite. TEST(RenameTest, OverwriteFailsWhenNewParentNotWritable) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto f1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); @@ -312,8 +312,8 @@ TEST(RenameTest, OverwriteFailsWhenNewParentNotWritable) { // because the user cannot determine if source exists. TEST(RenameTest, FileDoesNotExistWhenNewParentNotExecutable) { // Drop capabilities that allow us to override file and directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); // No execute permission. auto dir = ASSERT_NO_ERRNO_AND_VALUE( diff --git a/test/syscalls/linux/rlimits.cc b/test/syscalls/linux/rlimits.cc index 860f0f688..d31a2a880 100644 --- a/test/syscalls/linux/rlimits.cc +++ b/test/syscalls/linux/rlimits.cc @@ -41,9 +41,7 @@ TEST(RlimitTest, SetRlimitHigher) { TEST(RlimitTest, UnprivilegedSetRlimit) { // Drop privileges if necessary. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_RESOURCE))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_RESOURCE, false)); - } + AutoCapability cap(CAP_SYS_RESOURCE, false); struct rlimit rl = {}; rl.rlim_cur = 1000; diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc index be2364fb8..d74096ded 100644 --- a/test/syscalls/linux/select.cc +++ b/test/syscalls/linux/select.cc @@ -98,7 +98,7 @@ TEST_F(SelectTest, ZeroTimeout) { // If random S/R interrupts the select, SIGALRM may be delivered before select // restarts, causing the select to hang forever. -TEST_F(SelectTest, NoTimeout_NoRandomSave) { +TEST_F(SelectTest, NoTimeout) { // When there's no timeout, select may never return so set a timer. SetTimer(absl::Milliseconds(100)); // See that we get interrupted by the timer. @@ -118,7 +118,7 @@ TEST_F(SelectTest, InvalidTimeoutNegative) { // // If random S/R interrupts the select, SIGALRM may be delivered before select // restarts, causing the select to hang forever. 
-TEST_F(SelectTest, InterruptedBySignal_NoRandomSave) { +TEST_F(SelectTest, InterruptedBySignal) { absl::Duration duration(absl::Seconds(5)); struct timeval timeout = absl::ToTimeval(duration); SetTimer(absl::Milliseconds(100)); diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index 28f51a3bf..2ce8f836c 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -234,14 +234,6 @@ TEST(SemaphoreTest, SemTimedOpBlock) { AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); - ScopedThread th([&sem] { - absl::SleepFor(absl::Milliseconds(100)); - - struct sembuf buf = {}; - buf.sem_op = 1; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - }); - struct sembuf buf = {}; buf.sem_op = -1; struct timespec timeout = {}; @@ -295,7 +287,7 @@ TEST(SemaphoreTest, SemOpSimple) { // Tests that semaphore can be removed while there are waiters. // NoRandomSave: Test relies on timing that random save throws off. -TEST(SemaphoreTest, SemOpRemoveWithWaiter_NoRandomSave) { +TEST(SemaphoreTest, SemOpRemoveWithWaiter) { AutoSem sem(semget(IPC_PRIVATE, 2, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -543,7 +535,7 @@ TEST(SemaphoreTest, SemCtlGetPidFork) { TEST(SemaphoreTest, SemIpcSet) { // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -568,7 +560,7 @@ TEST(SemaphoreTest, SemIpcSet) { TEST(SemaphoreTest, SemCtlIpcStat) { // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); const uid_t kUid = getuid(); const gid_t kGid = getgid(); time_t start_time = time(nullptr); @@ -643,7 +635,7 @@ PosixErrorOr<int> WaitSemctl(int semid, int target, int cmd) { TEST(SemaphoreTest, SemopGetzcnt) { // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); // Create a write only semaphore set. AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -716,7 +708,7 @@ TEST(SemaphoreTest, SemopGetzcntOnSetRemoval) { EXPECT_THAT(semctl(semid, 0, GETZCNT), SyscallFailsWithErrno(EINVAL)); } -TEST(SemaphoreTest, SemopGetzcntOnSignal_NoRandomSave) { +TEST(SemaphoreTest, SemopGetzcntOnSignal) { AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); @@ -751,7 +743,7 @@ TEST(SemaphoreTest, SemopGetzcntOnSignal_NoRandomSave) { TEST(SemaphoreTest, SemopGetncnt) { // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); // Create a write only semaphore set. AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -821,7 +813,7 @@ TEST(SemaphoreTest, SemopGetncntOnSetRemoval) { EXPECT_THAT(semctl(semid, 0, GETNCNT), SyscallFailsWithErrno(EINVAL)); } -TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) { +TEST(SemaphoreTest, SemopGetncntOnSignal) { AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); ASSERT_EQ(semctl(sem.get(), 0, GETNCNT), 0); @@ -861,7 +853,7 @@ TEST(SemaphoreTest, IpcInfo) { std::set<int> sem_ids; struct seminfo info; // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); for (int i = 0; i < kLoops; i++) { AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); @@ -931,7 +923,7 @@ TEST(SemaphoreTest, SemInfo) { std::set<int> sem_ids; struct seminfo info; // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + AutoCapability cap(CAP_IPC_OWNER, false); for (int i = 0; i < kLoops; i++) { AutoSem sem(semget(IPC_PRIVATE, kSemSetSize, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 93b3a94f1..bea4ee71c 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -654,7 +654,7 @@ TEST(SendFileTest, SendFileToPipe) { SyscallSucceedsWithValue(kDataSize)); } -TEST(SendFileTest, SendFileToSelf_NoRandomSave) { +TEST(SendFileTest, SendFileToSelf) { int rawfd; ASSERT_THAT(rawfd = memfd_create("memfd", 0), SyscallSucceeds()); const FileDescriptor fd(rawfd); @@ -675,7 +675,7 @@ TEST(SendFileTest, SendFileToSelf_NoRandomSave) { static volatile int signaled = 0; void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; } -TEST(SendFileTest, ToEventFDDoesNotSpin_NoRandomSave) { +TEST(SendFileTest, ToEventFDDoesNotSpin) { FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); // Write the maximum value of an eventfd to a file. diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc index 4f8afff15..21651a697 100644 --- a/test/syscalls/linux/sigtimedwait.cc +++ b/test/syscalls/linux/sigtimedwait.cc @@ -52,7 +52,7 @@ TEST(SigtimedwaitTest, InvalidTimeout) { // No random save as the test relies on alarm timing. Cooperative save tests // already cover the save between alarm and wait. 
-TEST(SigtimedwaitTest, AlarmReturnsAlarm_NoRandomSave) { +TEST(SigtimedwaitTest, AlarmReturnsAlarm) { struct itimerval itv = {}; itv.it_value.tv_sec = kAlarmSecs; const auto itimer_cleanup = @@ -69,7 +69,7 @@ TEST(SigtimedwaitTest, AlarmReturnsAlarm_NoRandomSave) { // No random save as the test relies on alarm timing. Cooperative save tests // already cover the save between alarm and wait. -TEST(SigtimedwaitTest, NullTimeoutReturnsEINTR_NoRandomSave) { +TEST(SigtimedwaitTest, NullTimeoutReturnsEINTR) { struct sigaction sa; sa.sa_sigaction = NoopHandler; sigfillset(&sa.sa_mask); diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index b616c2c87..7b966484d 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -47,7 +47,7 @@ TEST(SocketTest, ProtocolUnix) { {AF_UNIX, SOCK_SEQPACKET, PF_UNIX}, {AF_UNIX, SOCK_DGRAM, PF_UNIX}, }; - for (long unsigned int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(tests); i++) { ASSERT_NO_ERRNO_AND_VALUE( Socket(tests[i].domain, tests[i].type, tests[i].protocol)); } @@ -60,7 +60,7 @@ TEST(SocketTest, ProtocolInet) { {AF_INET, SOCK_DGRAM, IPPROTO_UDP}, {AF_INET, SOCK_STREAM, IPPROTO_TCP}, }; - for (long unsigned int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { + for (size_t i = 0; i < ABSL_ARRAYSIZE(tests); i++) { ASSERT_NO_ERRNO_AND_VALUE( Socket(tests[i].domain, tests[i].type, tests[i].protocol)); } @@ -111,7 +111,7 @@ TEST(SocketTest, UnixSocketStatFS) { EXPECT_EQ(st.f_namelen, NAME_MAX); } -TEST(SocketTest, UnixSCMRightsOnlyPassedOnce_NoRandomSave) { +TEST(SocketTest, UnixSCMRightsOnlyPassedOnce) { const DisableSave ds; int sockets[2]; diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc index f8a0a80f2..3b108cbd3 100644 --- a/test/syscalls/linux/socket_bind_to_device_distribution.cc +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -141,9 +141,8 @@ 
TEST_P(BindToDeviceDistributionTest, Tcp) { endpoint.bind_to_device.c_str(), endpoint.bind_to_device.size() + 1), SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); // On the first bind we need to determine which port was bound. @@ -154,8 +153,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -168,7 +166,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { std::vector<std::unique_ptr<ScopedThread>> listen_threads( listener_fds.size()); - for (long unsigned int i = 0; i < listener_fds.size(); i++) { + for (size_t i = 0; i < listener_fds.size(); i++) { listen_threads[i] = absl::make_unique<ScopedThread>( [&listener_fds, &accept_counts, &connects_received, i, kConnectAttempts]() { @@ -207,10 +205,9 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { for (int32_t i = 0; i < kConnectAttempts; i++) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT( - RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(fd.get(), AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceeds()); EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), SyscallSucceedsWithValue(sizeof(i))); @@ -221,7 +218,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { listen_thread->Join(); } // Check that connections are distributed correctly among 
listening sockets. - for (long unsigned int i = 0; i < accept_counts.size(); i++) { + for (size_t i = 0; i < accept_counts.size(); i++) { EXPECT_THAT( accept_counts[i], EquivalentWithin(static_cast<int>(kConnectAttempts * @@ -267,9 +264,8 @@ TEST_P(BindToDeviceDistributionTest, Udp) { endpoint.bind_to_device.c_str(), endpoint.bind_to_device.size() + 1), SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); // On the first bind we need to determine which port was bound. if (listener_fds.size() > 1) { @@ -279,8 +275,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) { // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -294,7 +289,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) { std::vector<std::unique_ptr<ScopedThread>> receiver_threads( listener_fds.size()); - for (long unsigned int i = 0; i < listener_fds.size(); i++) { + for (size_t i = 0; i < listener_fds.size(); i++) { receiver_threads[i] = absl::make_unique<ScopedThread>( [&listener_fds, &packets_per_socket, &packets_received, i]() { do { @@ -302,9 +297,9 @@ TEST_P(BindToDeviceDistributionTest, Udp) { socklen_t addrlen = sizeof(addr); int data; - auto ret = RetryEINTR(recvfrom)( - listener_fds[i].get(), &data, sizeof(data), 0, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen); + auto ret = + RetryEINTR(recvfrom)(listener_fds[i].get(), &data, sizeof(data), + 0, AsSockAddr(&addr), &addrlen); if (packets_received < kConnectAttempts) { ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); @@ -322,10 +317,10 @@ 
TEST_P(BindToDeviceDistributionTest, Udp) { // A response is required to synchronize with the main thread, // otherwise the main thread can send more than can fit into receive // queues. - EXPECT_THAT(RetryEINTR(sendto)( - listener_fds[i].get(), &data, sizeof(data), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(data))); + EXPECT_THAT( + RetryEINTR(sendto)(listener_fds[i].get(), &data, sizeof(data), + 0, AsSockAddr(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(data))); } while (packets_received < kConnectAttempts); // Shutdown all sockets to wake up other threads. @@ -339,8 +334,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) { FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), + AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceedsWithValue(sizeof(i))); int data; EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), @@ -352,7 +346,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) { receiver_thread->Join(); } // Check that packets are distributed correctly among listening sockets. - for (long unsigned int i = 0; i < packets_per_socket.size(); i++) { + for (size_t i = 0; i < packets_per_socket.size(); i++) { EXPECT_THAT( packets_per_socket[i], EquivalentWithin(static_cast<int>(kConnectAttempts * diff --git a/test/syscalls/linux/socket_capability.cc b/test/syscalls/linux/socket_capability.cc index 84b5b2b21..f75482aba 100644 --- a/test/syscalls/linux/socket_capability.cc +++ b/test/syscalls/linux/socket_capability.cc @@ -40,7 +40,7 @@ TEST(SocketTest, UnixConnectNeedsWritePerm) { // Drop capabilites that allow us to override permision checks. Otherwise if // the test is run as root, the connect below will bypass permission checks // and succeed unexpectedly. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); // Connect should fail without write perms. ASSERT_THAT(chmod(addr.sun_path, 0500), SyscallSucceeds()); diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 597b5bcb1..9a6b089f6 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -190,8 +190,7 @@ TEST_P(DualStackSocketTest, AddressOperations) { if (sockname) { sockaddr_storage sock_addr; socklen_t addrlen = sizeof(sock_addr); - ASSERT_THAT(getsockname(fd.get(), reinterpret_cast<sockaddr*>(&sock_addr), - &addrlen), + ASSERT_THAT(getsockname(fd.get(), AsSockAddr(&sock_addr), &addrlen), SyscallSucceeds()); ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); @@ -200,24 +199,23 @@ TEST_P(DualStackSocketTest, AddressOperations) { if (operation == Operation::SendTo) { EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6); EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32)) - << OperationToString(operation) << " getsocknam=" - << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr)); + << OperationToString(operation) + << " getsocknam=" << GetAddrStr(AsSockAddr(&sock_addr)); EXPECT_NE(sock_addr_in6->sin6_port, 0); } else if (IN6_IS_ADDR_V4MAPPED( reinterpret_cast<const sockaddr_in6*>(addr_in) ->sin6_addr.s6_addr32)) { EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32)) - << OperationToString(operation) << " getsocknam=" - << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr)); + << OperationToString(operation) + << " getsocknam=" << GetAddrStr(AsSockAddr(&sock_addr)); } } if (peername) { sockaddr_storage peer_addr; socklen_t addrlen = sizeof(peer_addr); - ASSERT_THAT(getpeername(fd.get(), reinterpret_cast<sockaddr*>(&peer_addr), - &addrlen), + ASSERT_THAT(getpeername(fd.get(), AsSockAddr(&peer_addr), &addrlen), SyscallSucceeds()); ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); 
@@ -227,8 +225,8 @@ TEST_P(DualStackSocketTest, AddressOperations) { EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED( reinterpret_cast<const sockaddr_in6*>(&peer_addr) ->sin6_addr.s6_addr32)) - << OperationToString(operation) << " getpeername=" - << GetAddrStr(reinterpret_cast<sockaddr*>(&peer_addr)); + << OperationToString(operation) + << " getpeername=" << GetAddrStr(AsSockAddr(&peer_addr)); } } } @@ -265,16 +263,15 @@ void tcpSimpleConnectTest(TestAddress const& listener, Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; if (!unbound) { - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); } ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -284,8 +281,7 @@ void tcpSimpleConnectTest(TestAddress const& listener, Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); @@ -331,9 +327,9 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), 
reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds()); @@ -341,8 +337,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); const uint16_t port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -357,8 +352,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { for (int i = 0; i < kBacklog; i++) { auto client = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(RetryEINTR(connect)(client.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(client.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); } @@ -380,15 +374,14 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); // Get the port bound by the listening socket. 
socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -402,8 +395,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { for (int i = 0; i < kFDs; i++) { auto client = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(RetryEINTR(connect)(client.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(client.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds()); @@ -420,8 +412,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); ASSERT_THAT( - bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), + bind(new_listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // Check that subsequent connection attempts receive a RST. 
@@ -431,8 +422,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdown) { for (int i = 0; i < kFDs; i++) { auto client = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(RetryEINTR(connect)(client.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(client.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallFailsWithErrno(ECONNREFUSED)); } @@ -452,15 +442,14 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -471,8 +460,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { for (int i = 0; i < kFDs; i++) { auto client = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); + int ret = connect(client.get(), AsSockAddr(&conn_addr), connector.addr_len); if (ret != 0) { EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); } @@ -484,93 +472,160 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) { } } -void TestListenWhileConnect(const TestParam& param, - void (*stopListen)(FileDescriptor&)) { +void TestHangupDuringConnect(const 
TestParam& param, + void (*hangup)(FileDescriptor&)) { TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; - constexpr int kBacklog = 2; - // Linux completes one more connection than the listen backlog argument. - // To ensure that there is at least one client connection that stays in - // connecting state, keep 2 more client connections than the listen backlog. - // gVisor differs in this behavior though, gvisor.dev/issue/3153. - constexpr int kClients = kBacklog + 2; + for (int i = 0; i < 100; i++) { + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), 0), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + // Connect asynchronously and immediately hang up the listener. + FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); + } + + hangup(listen_fd); + + // Wait for the connection to close. 
+ struct pollfd pfd = { + .fd = client.get(), + }; + constexpr int kTimeout = 10000; + int n = poll(&pfd, 1, kTimeout); + ASSERT_GE(n, 0) << strerror(errno); + ASSERT_EQ(n, 1); + ASSERT_EQ(pfd.revents, POLLHUP | POLLERR); + ASSERT_EQ(close(client.release()), 0) << strerror(errno); + } +} + +TEST_P(SocketInetLoopbackTest, TCPListenCloseDuringConnect) { + TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(close(f.release()), SyscallSucceeds()); + }); +} + +TEST_P(SocketInetLoopbackTest, TCPListenShutdownDuringConnect) { + TestHangupDuringConnect(GetParam(), [](FileDescriptor& f) { + ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); + }); +} + +void TestListenHangupConnectingRead(const TestParam& param, + void (*hangup)(FileDescriptor&)) { + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; // Create the listening socket. FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); + // This test is only interested in deterministically getting a socket in + // connecting state. For that, we use a listen backlog of zero which would + // mean there is exactly one connection that gets established and is enqueued + // to the accept queue. We poll on the listener to ensure that is enqueued. + // After that the subsequent client connect will stay in connecting state as + // the accept queue is full. + constexpr int kBacklog = 0; ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); // Get the port bound by the listening socket. 
socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - std::vector<FileDescriptor> clients; - for (int i = 0; i < kClients; i++) { - FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - clients.push_back(std::move(client)); - } + FileDescriptor established_client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(established_client.get(), AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Ensure that the accept queue has the completed connection. + constexpr int kTimeout = 10000; + pollfd pfd = { + .fd = listen_fd.get(), + .events = POLLIN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN); + + FileDescriptor connecting_client = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + // Keep the last client in connecting state. 
+ int ret = connect(connecting_client.get(), AsSockAddr(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); } - stopListen(listen_fd); + hangup(listen_fd); - for (auto& client : clients) { - constexpr int kTimeout = 10000; + std::array<std::pair<int, int>, 2> sockets = { + std::make_pair(established_client.get(), ECONNRESET), + std::make_pair(connecting_client.get(), ECONNREFUSED), + }; + for (size_t i = 0; i < sockets.size(); i++) { + SCOPED_TRACE(absl::StrCat("i=", i)); + auto [fd, expected_errno] = sockets[i]; pollfd pfd = { - .fd = client.get(), - .events = POLLIN, + .fd = fd, }; - // When the listening socket is closed, then we expect the remote to reset - // the connection. - ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); - ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR); + // When the listening socket is closed, the peer would reset the connection. + EXPECT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + EXPECT_EQ(pfd.revents, POLLHUP | POLLERR); char c; - // Subsequent read can fail with: - // ECONNRESET: If the client connection was established and was reset by the - // remote. - // ECONNREFUSED: If the client connection failed to be established. - ASSERT_THAT(read(client.get(), &c, sizeof(c)), - AnyOf(SyscallFailsWithErrno(ECONNRESET), - SyscallFailsWithErrno(ECONNREFUSED))); - // The last client connection would be in connecting (SYN_SENT) state. 
- if (client.get() == clients[kClients - 1].get()) { - ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno); - } + EXPECT_THAT(read(fd, &c, sizeof(c)), SyscallFailsWithErrno(expected_errno)); } } -TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) { - TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { +TEST_P(SocketInetLoopbackTest, TCPListenCloseConnectingRead) { + TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) { ASSERT_THAT(close(f.release()), SyscallSucceeds()); }); } -TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { - TestListenWhileConnect(GetParam(), [](FileDescriptor& f) { +TEST_P(SocketInetLoopbackTest, TCPListenShutdownConnectingRead) { + TestListenHangupConnectingRead(GetParam(), [](FileDescriptor& f) { ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds()); }); } -// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/ +// TODO(b/157236388): Remove once bug is fixed. Test fails w/ // random save as established connections which can't be delivered to the accept // queue because the queue is full are not correctly delivered after restore // causing the last accept to timeout on the restore. -TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPAcceptBacklogSizes) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -580,21 +635,70 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); + // Get the port bound by the listening socket. 
+ socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + std::array<int, 3> backlogs = {-1, 0, 1}; + for (auto& backlog : backlogs) { + ASSERT_THAT(listen(listen_fd.get(), backlog), SyscallSucceeds()); + + int expected_accepts; + if (backlog < 0) { + expected_accepts = 1024; + } else { + expected_accepts = backlog + 1; + } + for (int i = 0; i < expected_accepts; i++) { + SCOPED_TRACE(absl::StrCat("i=", i)); + // Connect to the listening socket. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + const FileDescriptor accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + } + } +} + +// TODO(b/157236388): Remove once bug is fixed. Test fails w/ +// random save as established connections which can't be delivered to the accept +// queue because the queue is full are not correctly delivered after restore +// causing the last accept to timeout on the restore. +TEST_P(SocketInetLoopbackTest, TCPBacklog) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. 
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); constexpr int kBacklogSize = 2; ASSERT_THAT(listen(listen_fd.get(), kBacklogSize), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); int i = 0; while (1) { + SCOPED_TRACE(absl::StrCat("i=", i)); int ret; // Connect to the listening socket. @@ -602,8 +706,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ret = connect(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); + ret = connect(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len); if (ret != 0) { EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); pollfd pfd = { @@ -620,103 +723,130 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { i++; } + int client_conns = i; + int accepted_conns = 0; for (; i != 0; i--) { - // Accept the connection. - // - // We have to assign a name to the accepted socket, as unamed temporary - // objects are destructed upon full evaluation of the expression it is in, - // potentially causing the connecting socket to fail to shutdown properly. 
- auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + SCOPED_TRACE(absl::StrCat("i=", i)); + pollfd pfd = { + .fd = listen_fd.get(), + .events = POLLIN, + }; + // Look for incoming connections to accept. The last connect request could + // be established from the client side, but the ACK of the handshake could + // be dropped by the listener if the accept queue was filled up by the + // previous connect. + int ret; + ASSERT_THAT(ret = poll(&pfd, 1, 3000), SyscallSucceeds()); + if (ret == 0) break; + if (pfd.revents == POLLIN) { + // Accept the connection. + // + // We have to assign a name to the accepted socket, as unamed temporary + // objects are destructed upon full evaluation of the expression it is in, + // potentially causing the connecting socket to fail to shutdown properly. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + accepted_conns++; + } } + // We should accept at least listen backlog + 1 connections. As the stack is + // enqueuing established connections to the accept queue, newer SYNs could + // still be replied to causing those client connections would be accepted as + // we start dequeuing the queue. + ASSERT_GE(accepted_conns, kBacklogSize + 1); + ASSERT_GE(client_conns, accepted_conns); } -// Test if the stack completes atmost listen backlog number of client -// connections. It exercises the path of the stack that enqueues completed -// connections to accept queue vs new incoming SYNs. -TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) { - const auto& param = GetParam(); - const TestAddress& listener = param.listener; - const TestAddress& connector = param.connector; +// TODO(b/157236388): Remove once bug is fixed. Test fails w/ +// random save as established connections which can't be delivered to the accept +// queue because the queue is full are not correctly delivered after restore +// causing the last accept to timeout on the restore. 
+TEST_P(SocketInetLoopbackTest, TCPBacklogAcceptAll) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + // Create the listening socket. + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); constexpr int kBacklog = 1; - // Keep the number of client connections more than the listen backlog. - // Linux completes one more connection than the listen backlog argument. - // gVisor differs in this behavior though, gvisor.dev/issue/3153. - int kClients = kBacklog + 2; - if (IsRunningOnGvisor()) { - kClients--; - } + ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); - // Run the following test for few iterations to test race between accept queue - // getting filled with incoming SYNs. - for (int num = 0; num < 10; num++) { - FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + // Get the port bound by the listening socket. 
+ socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - socklen_t addrlen = listener.addr_len; - ASSERT_THAT( - getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - std::vector<FileDescriptor> clients; - // Issue multiple non-blocking client connects. - for (int i = 0; i < kClients; i++) { - FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - } - clients.push_back(std::move(client)); + // Fill up the accept queue and trigger more client connections which would be + // waiting to be accepted. 
+ std::array<FileDescriptor, kBacklog + 1> established_clients; + for (auto& fd : established_clients) { + fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(fd.get(), AsSockAddr(&conn_addr), connector.addr_len), + SyscallSucceeds()); + } + std::array<FileDescriptor, kBacklog> waiting_clients; + for (auto& fd : waiting_clients) { + fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(fd.get(), AsSockAddr(&conn_addr), connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); } + } - // Now that client connects are issued, wait for the accept queue to get - // filled and ensure no new client connection is completed. - for (int i = 0; i < kClients; i++) { - pollfd pfd = { - .fd = clients[i].get(), - .events = POLLOUT, - }; - if (i < kClients - 1) { - // Poll for client side connection completions with a large timeout. - // We cannot poll on the listener side without calling accept as poll - // stays level triggered with non-zero accept queue length. - // - // Client side poll would not guarantee that the completed connection - // has been enqueued in to the acccept queue, but the fact that the - // listener ACKd the SYN, means that it cannot complete any new incoming - // SYNs when it has already ACKd for > backlog number of SYNs. - ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1)) - << "num=" << num << " i=" << i << " kClients=" << kClients; - ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i; - } else { - // Now that we expect accept queue filled up, ensure that the last - // client connection never completes with a smaller poll timeout. 
- ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0)) - << "num=" << num << " i=" << i; - } + auto accept_connection = [&]() { + constexpr int kTimeout = 10000; + pollfd pfd = { + .fd = listen_fd.get(), + .events = POLLIN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN); + // Accept the connection. + // + // We have to assign a name to the accepted socket, as unamed temporary + // objects are destructed upon full evaluation of the expression it is in, + // potentially causing the connecting socket to fail to shutdown properly. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + }; - ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0)) - << "num=" << num << " i=" << i; - } - clients.clear(); - // We close the listening side and open a new listener. We could instead - // drain the accept queue by calling accept() and reuse the listener, but - // that is racy as the retransmitted SYNs could get ACKd as we make room in - // the accept queue. - ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0)); + // Ensure that we accept all client connections. The waiting connections would + // get enqueued as we drain the accept queue. + for (int i = 0; i < std::size(established_clients); i++) { + SCOPED_TRACE(absl::StrCat("established clients i=", i)); + accept_connection(); + } + + // The waiting client connections could be in one of these 2 states: + // (1) SYN_SENT: if the SYN was dropped because accept queue was full + // (2) ESTABLISHED: if the listener sent back a SYNACK, but may have dropped + // the ACK from the client if the accept queue was full (send out a data to + // re-send that ACK, to address that case). 
+ for (int i = 0; i < std::size(waiting_clients); i++) { + SCOPED_TRACE(absl::StrCat("waiting clients i=", i)); + constexpr int kTimeout = 10000; + pollfd pfd = { + .fd = waiting_clients[i].get(), + .events = POLLOUT, + }; + EXPECT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + EXPECT_EQ(pfd.revents, POLLOUT); + char c; + EXPECT_THAT(RetryEINTR(send)(waiting_clients[i].get(), &c, sizeof(c), 0), + SyscallSucceedsWithValue(sizeof(c))); + accept_connection(); } } @@ -728,7 +858,7 @@ TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) { // // TCP timers are not S/R today, this can cause this test to be flaky when run // under random S/R due to timer being reset on a restore. -TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPFinWait2Test) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -737,15 +867,14 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. 
socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = @@ -763,8 +892,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); @@ -776,8 +904,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { sockaddr_storage conn_bound_addr; socklen_t conn_addrlen = connector.addr_len; ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - &conn_addrlen), + getsockname(conn_fd.get(), AsSockAddr(&conn_bound_addr), &conn_addrlen), SyscallSucceeds()); // close the connecting FD to trigger FIN_WAIT2 on the connected fd. @@ -792,8 +919,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { // be restarted causing the final bind/connect to fail. 
DisableSave ds; - ASSERT_THAT(bind(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + ASSERT_THAT(bind(conn_fd2.get(), AsSockAddr(&conn_bound_addr), conn_addrlen), SyscallFailsWithErrno(EADDRINUSE)); // Sleep for a little over the linger timeout to reduce flakiness in @@ -802,10 +928,9 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { ds.reset(); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), - SyscallSucceeds()); + ASSERT_THAT( + RetryEINTR(connect)(conn_fd2.get(), AsSockAddr(&conn_addr), conn_addrlen), + SyscallSucceeds()); } // TCPLinger2TimeoutAfterClose creates a pair of connected sockets @@ -815,7 +940,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { // // TCP timers are not S/R today, this can cause this test to be flaky when run // under random S/R due to timer being reset on a restore. -TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -824,15 +949,14 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. 
socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = @@ -844,8 +968,7 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); @@ -857,8 +980,7 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { sockaddr_storage conn_bound_addr; socklen_t conn_addrlen = connector.addr_len; ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - &conn_addrlen), + getsockname(conn_fd.get(), AsSockAddr(&conn_bound_addr), &conn_addrlen), SyscallSucceeds()); // Disable cooperative saves after this point as TCP timers are not restored @@ -884,13 +1006,11 @@ TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(bind(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), + ASSERT_THAT(bind(conn_fd2.get(), AsSockAddr(&conn_bound_addr), conn_addrlen), SyscallSucceeds()); + ASSERT_THAT( + RetryEINTR(connect)(conn_fd2.get(), AsSockAddr(&conn_addr), conn_addrlen), + SyscallSucceeds()); } // TCPResetAfterClose creates a pair of connected sockets then closes @@ -906,15 +1026,14 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( 
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = @@ -926,8 +1045,7 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); @@ -975,15 +1093,14 @@ void setupTimeWaitClose(const TestAddress* listener, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); } - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(listen_addr), - listener->addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(listen_addr), listener->addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. 
socklen_t addrlen = listener->addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = @@ -1005,8 +1122,7 @@ void setupTimeWaitClose(const TestAddress* listener, sockaddr_storage conn_addr = connector->addr; ASSERT_NO_ERRNO(SetAddrPort(connector->family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector->addr_len), SyscallSucceeds()); @@ -1017,8 +1133,7 @@ void setupTimeWaitClose(const TestAddress* listener, // Get the address/port bound by the connecting socket. socklen_t conn_addrlen = connector->addr_len; ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(conn_bound_addr), - &conn_addrlen), + getsockname(conn_fd.get(), AsSockAddr(conn_bound_addr), &conn_addrlen), SyscallSucceeds()); FileDescriptor active_closefd, passive_closefd; @@ -1064,7 +1179,7 @@ void setupTimeWaitClose(const TestAddress* listener, // // Test re-binding of client and server bound addresses when the older // connection is in TIME_WAIT. -TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest) { auto const& param = GetParam(); sockaddr_storage listen_addr, conn_bound_addr; listen_addr = param.listener.addr; @@ -1075,19 +1190,18 @@ TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { // bound by the conn_fd as it never entered TIME_WAIT. 
const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), param.connector.addr_len), SyscallSucceeds()); FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - param.listener.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), param.listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); } -TEST_P(SocketInetLoopbackTest, - TCPPassiveCloseNoTimeWaitReuseTest_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitReuseTest) { auto const& param = GetParam(); sockaddr_storage listen_addr, conn_bound_addr; listen_addr = param.listener.addr; @@ -1099,9 +1213,9 @@ TEST_P(SocketInetLoopbackTest, ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - param.listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), param.listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Now bind and connect new socket and verify that we can immediately rebind @@ -1111,7 +1225,7 @@ TEST_P(SocketInetLoopbackTest, ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), param.connector.addr_len), SyscallSucceeds()); @@ -1119,13 +1233,12 @@ TEST_P(SocketInetLoopbackTest, 
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(param.listener.family(), listen_addr)); sockaddr_storage conn_addr = param.connector.addr; ASSERT_NO_ERRNO(SetAddrPort(param.connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), param.connector.addr_len), SyscallSucceeds()); } -TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest) { auto const& param = GetParam(); sockaddr_storage listen_addr, conn_bound_addr; listen_addr = param.listener.addr; @@ -1134,12 +1247,12 @@ TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) { FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), param.connector.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } -TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest) { auto const& param = GetParam(); sockaddr_storage listen_addr, conn_bound_addr; listen_addr = param.listener.addr; @@ -1150,7 +1263,7 @@ TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest_NoRandomSave) { ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + ASSERT_THAT(bind(conn_fd.get(), AsSockAddr(&conn_bound_addr), param.connector.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } @@ -1164,15 +1277,14 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage 
listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); const uint16_t port = @@ -1190,8 +1302,7 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); @@ -1218,17 +1329,16 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. { socklen_t addrlen = listener.addr_len; ASSERT_THAT( - getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - &addrlen), + getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); } @@ -1244,8 +1354,7 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { // TODO(b/157236388): Reenable Cooperative S/R once bug is fixed. 
DisableSave ds; - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); @@ -1272,8 +1381,8 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { sockaddr_storage accept_addr; socklen_t addrlen = sizeof(accept_addr); - auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept( - listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen)); + auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE( + Accept(listen_fd.get(), AsSockAddr(&accept_addr), &addrlen)); ASSERT_EQ(addrlen, listener.addr_len); // Wait for accept_fd to process the RST. @@ -1311,15 +1420,14 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { sockaddr_storage peer_addr; socklen_t addrlen = sizeof(peer_addr); // The socket is not connected anymore and should return ENOTCONN. - ASSERT_THAT(getpeername(accept_fd.get(), - reinterpret_cast<sockaddr*>(&peer_addr), &addrlen), + ASSERT_THAT(getpeername(accept_fd.get(), AsSockAddr(&peer_addr), &addrlen), SyscallFailsWithErrno(ENOTCONN)); } } // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not // saved. Enable S/R once issue is fixed. -TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPDeferAccept) { // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not // saved. Enable S/R issue is fixed. 
DisableSave ds; @@ -1332,15 +1440,14 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); const uint16_t port = @@ -1358,8 +1465,7 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); @@ -1401,7 +1507,7 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not // saved. Enable S/R once issue is fixed. -TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout) { // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not // saved. Enable S/R once issue is fixed. 
DisableSave ds; @@ -1414,15 +1520,14 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) { const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); const uint16_t port = @@ -1440,8 +1545,7 @@ TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), - reinterpret_cast<sockaddr*>(&conn_addr), + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceeds()); @@ -1507,9 +1611,9 @@ INSTANTIATE_TEST_SUITE_P( using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>; -// TODO(gvisor.dev/issue/940): Remove _NoRandomSave when portHint/stack.Seed is +// TODO(gvisor.dev/issue/940): Remove when portHint/stack.Seed is // saved/restored. 
-TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { +TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -1529,9 +1633,8 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); // On the first bind we need to determine which port was bound. @@ -1542,8 +1645,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -1601,10 +1703,9 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { for (int32_t i = 0; i < kConnectAttempts; i++) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT( - RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(fd.get(), AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceeds()); EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), SyscallSucceedsWithValue(sizeof(i))); @@ -1622,7 +1723,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, 
UdpPortReuseMultiThread_NoRandomSave) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -1641,9 +1742,8 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); // On the first bind we need to determine which port was bound. if (i != 0) { @@ -1653,8 +1753,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -1677,9 +1776,9 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { socklen_t addrlen = sizeof(addr); int data; - auto ret = RetryEINTR(recvfrom)( - listener_fds[i].get(), &data, sizeof(data), 0, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen); + auto ret = + RetryEINTR(recvfrom)(listener_fds[i].get(), &data, sizeof(data), + 0, AsSockAddr(&addr), &addrlen); if (packets_received < kConnectAttempts) { ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); @@ -1697,10 +1796,10 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { // A response is required to synchronize with the main thread, // otherwise the main thread can send more than can fit into receive // queues. 
- EXPECT_THAT(RetryEINTR(sendto)( - listener_fds[i].get(), &data, sizeof(data), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(data))); + EXPECT_THAT( + RetryEINTR(sendto)(listener_fds[i].get(), &data, sizeof(data), + 0, AsSockAddr(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(data))); } while (packets_received < kConnectAttempts); // Shutdown all sockets to wake up other threads. @@ -1713,10 +1812,10 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { for (int i = 0; i < kConnectAttempts; i++) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); - EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), - SyscallSucceedsWithValue(sizeof(i))); + EXPECT_THAT( + RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, AsSockAddr(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); int data; EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), SyscallSucceedsWithValue(sizeof(data))); @@ -1735,7 +1834,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } -TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -1757,9 +1856,8 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT( - bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(fd, AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); // On the first bind we need to determine which port was bound. 
if (i != 0) { @@ -1769,8 +1867,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; ASSERT_THAT( - getsockname(listener_fds[0].get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + getsockname(listener_fds[0].get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -1787,8 +1884,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { client_fds[i] = ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), + AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceedsWithValue(sizeof(i))); } ds.reset(); @@ -1797,8 +1893,7 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { // not been change after save/restore. 
for (int i = 0; i < kConnectAttempts; i++) { EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, - reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), + AsSockAddr(&conn_addr), connector.addr_len), SyscallSucceedsWithValue(sizeof(i))); } @@ -1826,9 +1921,8 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) { struct sockaddr_storage addr = {}; socklen_t addrlen = sizeof(addr); int data; - EXPECT_THAT(RetryEINTR(recvfrom)( - fd, &data, sizeof(data), 0, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + EXPECT_THAT(RetryEINTR(recvfrom)(fd, &data, sizeof(data), 0, + AsSockAddr(&addr), &addrlen), SyscallSucceedsWithValue(sizeof(data))); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr)); @@ -1882,14 +1976,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) { sockaddr_storage addr_dual = test_addr_dual.addr; const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_dual.family(), param.type, 0)); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len), + SyscallSucceeds()); // Get the port that we bound. 
socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); @@ -1900,8 +1993,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) { ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - int ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len); + int ret = bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len); if (ret == -1 && errno == EADDRINUSE) { // Port may have been in use. ASSERT_LT(i, 100); // Give up after 100 tries. @@ -1916,8 +2008,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedLoopbackOnlyReservesV4) { ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); const FileDescriptor fd_v4 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), + ASSERT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // No need to try again. @@ -1934,14 +2025,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) { sockaddr_storage addr_dual = test_addr_dual.addr; const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_dual.family(), param.type, 0)); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len), + SyscallSucceeds()); // Get the port that we bound. 
socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); @@ -1952,8 +2042,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) { ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - int ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len); + int ret = bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len); if (ret == -1 && errno == EADDRINUSE) { // Port may have been in use. ASSERT_LT(i, 100); // Give up after 100 tries. @@ -1968,8 +2057,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedAnyOnlyReservesV4) { ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); const FileDescriptor fd_v4 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), + ASSERT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // No need to try again. @@ -1985,14 +2073,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { sockaddr_storage addr_dual = test_addr_dual.addr; const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len), + SyscallSucceeds()); // Get the port that we bound. 
socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); @@ -2003,8 +2090,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len), + ASSERT_THAT(bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v4 loopback on the same port with a v6 socket @@ -2015,10 +2101,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port)); const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v4_mapped.family(), param.type, 0)); - ASSERT_THAT( - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); + ASSERT_THAT(bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped), + test_addr_v4_mapped.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v4 loopback on the same port with a v4 socket // fails. 
@@ -2027,8 +2112,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); const FileDescriptor fd_v4 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), + ASSERT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v4 any on the same port with a v4 socket @@ -2038,7 +2122,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v4_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + ASSERT_THAT(bind(fd_v4_any.get(), AsSockAddr(&addr_v4_any), test_addr_v4_any.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } @@ -2055,14 +2139,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len), + SyscallSucceeds()); // Get the port that we bound. 
socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); @@ -2076,7 +2159,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + ASSERT_THAT(bind(fd_v4_any.get(), AsSockAddr(&addr_v4_any), test_addr_v4_any.addr_len), SyscallSucceeds()); } @@ -2096,16 +2179,15 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds()); // Get the port that we bound. 
socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); @@ -2120,7 +2202,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + ASSERT_THAT(bind(fd_v4_any.get(), AsSockAddr(&addr_v4_any), test_addr_v4_any.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } @@ -2137,16 +2219,15 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, sockaddr_storage addr_dual = test_addr_dual.addr; const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds()); // Get the port that we bound. 
socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); @@ -2157,8 +2238,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len), + ASSERT_THAT(bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v4 loopback on the same port with a v6 socket @@ -2169,10 +2249,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port)); const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v4_mapped.family(), param.type, 0)); - ASSERT_THAT( - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); + ASSERT_THAT(bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped), + test_addr_v4_mapped.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v4 loopback on the same port with a v4 socket // fails. 
@@ -2181,8 +2260,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); const FileDescriptor fd_v4 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), + ASSERT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v4 any on the same port with a v4 socket @@ -2192,7 +2270,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v4_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + ASSERT_THAT(bind(fd_v4_any.get(), AsSockAddr(&addr_v4_any), test_addr_v4_any.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } @@ -2209,14 +2287,13 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { EXPECT_THAT(setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), - test_addr_dual.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(fd_dual.get(), AsSockAddr(&addr_dual), test_addr_dual.addr_len), + SyscallSucceeds()); // Get the port that we bound. 
socklen_t addrlen = test_addr_dual.addr_len; - ASSERT_THAT(getsockname(fd_dual.get(), - reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + ASSERT_THAT(getsockname(fd_dual.get(), AsSockAddr(&addr_dual), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); @@ -2227,8 +2304,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len), + ASSERT_THAT(bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // Verify that we can still bind the v4 loopback on the same port. @@ -2238,9 +2314,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port)); const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v4_mapped.family(), param.type, 0)); - int ret = - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len); + int ret = bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped), + test_addr_v4_mapped.addr_len); if (ret == -1 && errno == EADDRINUSE) { // Port may have been in use. ASSERT_LT(i, 100); // Give up after 100 tries. @@ -2262,9 +2337,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { sockaddr_storage bound_addr = test_addr.addr; const FileDescriptor bound_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), + SyscallSucceeds()); // Listen iff TCP. 
if (param.type == SOCK_STREAM) { @@ -2274,23 +2349,20 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { // Get the port that we bound. socklen_t bound_addr_len = test_addr.addr_len; ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), SyscallSucceeds()); // Connect to bind an ephemeral port. const FileDescriptor connected_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. sockaddr_storage connected_addr = {}; socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), &connected_addr_len), SyscallSucceeds()); uint16_t const ephemeral_port = @@ -2302,10 +2374,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { // Verify that the ephemeral port is reserved. const FileDescriptor checking_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallFailsWithErrno(EADDRINUSE)); + EXPECT_THAT(bind(checking_fd.get(), AsSockAddr(&connected_addr), + connected_addr_len), + SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v6 loopback with the same port fails. 
TestAddress const& test_addr_v6 = V6Loopback(); @@ -2314,8 +2385,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port)); const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len), + ASSERT_THAT(bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // Verify that we can still bind the v4 loopback on the same port. @@ -2325,9 +2395,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { ephemeral_port)); const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v4_mapped.family(), param.type, 0)); - int ret = - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len); + int ret = bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped), + test_addr_v4_mapped.addr_len); if (ret == -1 && errno == EADDRINUSE) { // Port may have been in use. ASSERT_LT(i, 100); // Give up after 100 tries. @@ -2348,8 +2417,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { sockaddr_storage bound_addr = test_addr.addr; const FileDescriptor bound_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), + ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), SyscallSucceeds()); ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), @@ -2363,8 +2431,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { // Get the port that we bound. 
socklen_t bound_addr_len = test_addr.addr_len; ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), SyscallSucceeds()); // Connect to bind an ephemeral port. @@ -2373,16 +2440,14 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. sockaddr_storage connected_addr = {}; socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), &connected_addr_len), SyscallSucceeds()); uint16_t const ephemeral_port = @@ -2398,8 +2463,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), + bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), SyscallSucceeds()); } @@ -2412,9 +2476,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { sockaddr_storage bound_addr = test_addr.addr; const FileDescriptor bound_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), + SyscallSucceeds()); // Listen iff TCP. 
if (param.type == SOCK_STREAM) { @@ -2424,23 +2488,20 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { // Get the port that we bound. socklen_t bound_addr_len = test_addr.addr_len; ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), SyscallSucceeds()); // Connect to bind an ephemeral port. const FileDescriptor connected_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. sockaddr_storage connected_addr = {}; socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), &connected_addr_len), SyscallSucceeds()); uint16_t const ephemeral_port = @@ -2452,10 +2513,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { // Verify that the ephemeral port is reserved. const FileDescriptor checking_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallFailsWithErrno(EADDRINUSE)); + EXPECT_THAT(bind(checking_fd.get(), AsSockAddr(&connected_addr), + connected_addr_len), + SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v4 loopback on the same port with a v4 socket // fails. 
@@ -2465,8 +2525,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { SetAddrPort(test_addr_v4.family(), &addr_v4, ephemeral_port)); const FileDescriptor fd_v4 = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); - EXPECT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), - test_addr_v4.addr_len), + EXPECT_THAT(bind(fd_v4.get(), AsSockAddr(&addr_v4), test_addr_v4.addr_len), SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v6 any on the same port with a dual-stack socket @@ -2477,7 +2536,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { SetAddrPort(test_addr_v6_any.family(), &addr_v6_any, ephemeral_port)); const FileDescriptor fd_v6_any = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v6_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), + ASSERT_THAT(bind(fd_v6_any.get(), AsSockAddr(&addr_v6_any), test_addr_v6_any.addr_len), SyscallFailsWithErrno(EADDRINUSE)); @@ -2496,8 +2555,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port)); const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v6.family(), param.type, 0)); - ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len); + ret = bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len); } else { // Verify that we can still bind the v6 any on the same port with a // v6-only socket. 
@@ -2506,9 +2564,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ret = - bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), - test_addr_v6_any.addr_len); + ret = bind(fd_v6_only_any.get(), AsSockAddr(&addr_v6_any), + test_addr_v6_any.addr_len); } if (ret == -1 && errno == EADDRINUSE) { @@ -2532,8 +2589,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, sockaddr_storage bound_addr = test_addr.addr; const FileDescriptor bound_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), + ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), SyscallSucceeds()); ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, @@ -2548,8 +2604,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, // Get the port that we bound. socklen_t bound_addr_len = test_addr.addr_len; ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), SyscallSucceeds()); // Connect to bind an ephemeral port. @@ -2558,16 +2613,14 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. 
sockaddr_storage connected_addr = {}; socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), &connected_addr_len), SyscallSucceeds()); uint16_t const ephemeral_port = @@ -2583,8 +2636,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), + bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), SyscallSucceeds()); } @@ -2597,9 +2649,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { sockaddr_storage bound_addr = test_addr.addr; const FileDescriptor bound_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), + SyscallSucceeds()); // Listen iff TCP. if (param.type == SOCK_STREAM) { @@ -2609,23 +2661,20 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { // Get the port that we bound. socklen_t bound_addr_len = test_addr.addr_len; ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), SyscallSucceeds()); // Connect to bind an ephemeral port. const FileDescriptor connected_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. 
sockaddr_storage connected_addr = {}; socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), &connected_addr_len), SyscallSucceeds()); uint16_t const ephemeral_port = @@ -2637,10 +2686,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { // Verify that the ephemeral port is reserved. const FileDescriptor checking_fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); - EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), - SyscallFailsWithErrno(EADDRINUSE)); + EXPECT_THAT(bind(checking_fd.get(), AsSockAddr(&connected_addr), + connected_addr_len), + SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v4 loopback on the same port with a v6 socket // fails. @@ -2650,10 +2698,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { ephemeral_port)); const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v4_mapped.family(), param.type, 0)); - EXPECT_THAT( - bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), - test_addr_v4_mapped.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); + EXPECT_THAT(bind(fd_v4_mapped.get(), AsSockAddr(&addr_v4_mapped), + test_addr_v4_mapped.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); // Verify that binding the v6 any on the same port with a dual-stack socket // fails. 
@@ -2663,7 +2710,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { SetAddrPort(test_addr_v6_any.family(), &addr_v6_any, ephemeral_port)); const FileDescriptor fd_v6_any = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v6_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v6_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), + ASSERT_THAT(bind(fd_v6_any.get(), AsSockAddr(&addr_v6_any), test_addr_v6_any.addr_len), SyscallFailsWithErrno(EADDRINUSE)); @@ -2682,8 +2729,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { SetAddrPort(test_addr_v6.family(), &addr_v6, ephemeral_port)); const FileDescriptor fd_v6 = ASSERT_NO_ERRNO_AND_VALUE( Socket(test_addr_v6.family(), param.type, 0)); - ret = bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), - test_addr_v6.addr_len); + ret = bind(fd_v6.get(), AsSockAddr(&addr_v6), test_addr_v6.addr_len); } else { // Verify that we can still bind the v6 any on the same port with a // v6-only socket. @@ -2692,9 +2738,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ret = - bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any), - test_addr_v6_any.addr_len); + ret = bind(fd_v6_only_any.get(), AsSockAddr(&addr_v6_any), + test_addr_v6_any.addr_len); } if (ret == -1 && errno == EADDRINUSE) { @@ -2722,8 +2767,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), + ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), SyscallSucceeds()); // Listen iff TCP. @@ -2734,8 +2778,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { // Get the port that we bound. 
socklen_t bound_addr_len = test_addr.addr_len; ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), SyscallSucceeds()); // Connect to bind an ephemeral port. @@ -2746,16 +2789,14 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), - reinterpret_cast<sockaddr*>(&bound_addr), + ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(), AsSockAddr(&bound_addr), bound_addr_len), SyscallSucceeds()); // Get the ephemeral port. sockaddr_storage connected_addr = {}; socklen_t connected_addr_len = sizeof(connected_addr); - ASSERT_THAT(getsockname(connected_fd.get(), - reinterpret_cast<sockaddr*>(&connected_addr), + ASSERT_THAT(getsockname(connected_fd.get(), AsSockAddr(&connected_addr), &connected_addr_len), SyscallSucceeds()); uint16_t const ephemeral_port = @@ -2771,8 +2812,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); EXPECT_THAT( - bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr), - connected_addr_len), + bind(checking_fd.get(), AsSockAddr(&connected_addr), connected_addr_len), SyscallSucceeds()); } @@ -2791,14 +2831,12 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), + ASSERT_THAT(bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), SyscallSucceeds()); // Get the port that we bound. 
socklen_t bound_addr_len = test_addr.addr_len; ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), SyscallSucceeds()); // Now create a socket and bind it to the same port, this should @@ -2809,9 +2847,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, ASSERT_THAT(setsockopt(second_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(second_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(second_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len), + SyscallSucceeds()); } TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { @@ -2830,10 +2868,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { setsockopt(fd1, SOL_SOCKET, SO_REUSEPORT, &portreuse1, sizeof(int)), SyscallSucceeds()); - ASSERT_THAT(bind(fd1, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(fd1, AsSockAddr(&addr), addrlen), SyscallSucceeds()); - ASSERT_THAT(getsockname(fd1, reinterpret_cast<sockaddr*>(&addr), &addrlen), + ASSERT_THAT(getsockname(fd1, AsSockAddr(&addr), &addrlen), SyscallSucceeds()); if (param.type == SOCK_STREAM) { ASSERT_THAT(listen(fd1, 1), SyscallSucceeds()); @@ -2852,7 +2889,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { SyscallSucceeds()); std::cout << portreuse1 << " " << portreuse2 << std::endl; - int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen); + int ret = bind(fd2, AsSockAddr(&addr), addrlen); // Verify that two sockets can be bound to the same port only if // SO_REUSEPORT is set for both of them. 
@@ -2880,10 +2917,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) { ASSERT_THAT( setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &portreuse, sizeof(portreuse)), SyscallSucceeds()); - ASSERT_THAT(bind(fd, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); - ASSERT_THAT(getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(fd, AsSockAddr(&addr), addrlen), SyscallSucceeds()); + ASSERT_THAT(getsockname(fd, AsSockAddr(&addr), &addrlen), SyscallSucceeds()); ASSERT_EQ(addrlen, test_addr.addr_len); s.reset(); @@ -2895,8 +2930,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) { ASSERT_THAT( setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &portreuse, sizeof(portreuse)), SyscallSucceeds()); - ASSERT_THAT(bind(fd, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(fd, AsSockAddr(&addr), addrlen), SyscallSucceeds()); } INSTANTIATE_TEST_SUITE_P( diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc index 1a0b53394..601ae107b 100644 --- a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -86,7 +86,7 @@ using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>; // We disable S/R because this test creates a large number of sockets. // // FIXME(b/162475855): This test is failing reliably. 
-TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -98,15 +98,14 @@ TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) { auto listen_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(listen_fd.get(), AsSockAddr(&listen_addr), listener.addr_len), + SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + ASSERT_THAT(getsockname(listen_fd.get(), AsSockAddr(&listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); @@ -124,8 +123,7 @@ TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) { for (int i = 0; i < kClients; i++) { FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); + int ret = connect(client.get(), AsSockAddr(&conn_addr), connector.addr_len); if (ret == 0) { clients.push_back(std::move(client)); FileDescriptor server = @@ -181,8 +179,7 @@ std::string DescribeProtocolTestParam( using SocketMultiProtocolInetLoopbackTest = ::testing::TestWithParam<ProtocolTestParam>; -TEST_P(SocketMultiProtocolInetLoopbackTest, - BindAvoidsListeningPortsReuseAddr_NoRandomSave) { +TEST_P(SocketMultiProtocolInetLoopbackTest, 
BindAvoidsListeningPortsReuseAddr) { const auto& param = GetParam(); // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP // this is only permitted if there is no other listening socket. @@ -205,8 +202,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - int ret = bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - test_addr.addr_len); + int ret = bind(bound_fd.get(), AsSockAddr(&bound_addr), test_addr.addr_len); if (ret != 0) { ASSERT_EQ(errno, EADDRINUSE); break; @@ -214,8 +210,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, // Get the port that we bound. socklen_t bound_addr_len = test_addr.addr_len; ASSERT_THAT( - getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), - &bound_addr_len), + getsockname(bound_fd.get(), AsSockAddr(&bound_addr), &bound_addr_len), SyscallSucceeds()); uint16_t port = reinterpret_cast<sockaddr_in*>(&bound_addr)->sin_port; diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index f10f55b27..59b56dc1a 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -1153,7 +1153,7 @@ TEST_P(TCPSocketPairTest, IpMulticastLoopDefault) { EXPECT_EQ(get, 1); } -TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { +TEST_P(TCPSocketPairTest, TCPResetDuringClose) { DisableSave ds; // Too many syscalls. constexpr int kThreadCount = 1000; std::unique_ptr<ScopedThread> instances[kThreadCount]; diff --git a/test/syscalls/linux/socket_ip_unbound_netlink.cc b/test/syscalls/linux/socket_ip_unbound_netlink.cc index 7fb1c0faf..b02222999 100644 --- a/test/syscalls/linux/socket_ip_unbound_netlink.cc +++ b/test/syscalls/linux/socket_ip_unbound_netlink.cc @@ -35,7 +35,7 @@ namespace testing { // Test fixture for tests that apply to pairs of IP sockets. 
using IPv6UnboundSocketTest = SimpleSocketTest; -TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) { +TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); // TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved @@ -57,8 +57,7 @@ TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) { TestAddress addr = V6Loopback(); reinterpret_cast<sockaddr_in6*>(&addr.addr)->sin6_port = 65535; auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + EXPECT_THAT(connect(sock->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EADDRNOTAVAIL)); } @@ -69,7 +68,7 @@ INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, IPv6UnboundSocketTest, using IPv4UnboundSocketTest = SimpleSocketTest; -TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) { +TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); // TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved @@ -90,8 +89,7 @@ TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) { TestAddress addr = V4Loopback(); reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = 65535; auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + EXPECT_THAT(connect(sock->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallFailsWithErrno(ENETUNREACH)); } diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 8eec31a46..18be4dcc7 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -44,20 +44,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { // IP_MULTICAST_IF for setting the default send 
interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), + bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address. If multicast worked like unicast, // this would ensure that we get the packet. auto receiver_addr = V4Any(); - EXPECT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -68,10 +65,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -83,19 +80,19 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { // Check that not setting a default send interface prevents multicast packets // from being sent. Group membership interface configured by address. 
TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) { + // TODO(b/185517803): Fix for native test. + SKIP_IF(!IsRunningOnGvisor()); auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the second FD to the v4 any address to ensure that we can receive any // unicast packet. auto receiver_addr = V4Any(); - EXPECT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -114,28 +111,28 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallFailsWithErrno(ENETUNREACH)); + EXPECT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallFailsWithErrno(ENETUNREACH)); } // Check that not setting a default send interface prevents multicast packets // from being sent. Group membership interface configured by NIC ID. TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) { + // TODO(b/185517803): Fix for native test. 
+ SKIP_IF(!IsRunningOnGvisor()); auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Bind the second FD to the v4 any address to ensure that we can receive any // unicast packet. auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -154,10 +151,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallFailsWithErrno(ENETUNREACH)); + EXPECT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallFailsWithErrno(ENETUNREACH)); } // Check that multicast works when the default send interface is configured by @@ -170,20 +167,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), + bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -202,10 +196,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -226,20 +220,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), + bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -258,10 +249,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -289,13 +280,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -314,10 +303,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -345,13 +334,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -370,10 +357,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -401,13 +388,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -425,8 +410,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(socket1->get(), - reinterpret_cast<sockaddr*>(&connect_addr.addr), + RetryEINTR(connect)(socket1->get(), AsSockAddr(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); @@ -461,13 +445,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -485,8 +467,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(socket1->get(), - reinterpret_cast<sockaddr*>(&connect_addr.addr), + RetryEINTR(connect)(socket1->get(), AsSockAddr(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); @@ -521,13 +502,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -546,10 +525,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -577,13 +556,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -602,10 +579,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -633,13 +610,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -657,8 +632,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; EXPECT_THAT( - RetryEINTR(connect)(socket1->get(), - reinterpret_cast<sockaddr*>(&connect_addr.addr), + RetryEINTR(connect)(socket1->get(), AsSockAddr(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); @@ -691,13 +665,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -715,8 +687,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; ASSERT_THAT( - RetryEINTR(connect)(socket1->get(), - reinterpret_cast<sockaddr*>(&connect_addr.addr), + RetryEINTR(connect)(socket1->get(), AsSockAddr(&connect_addr.addr), connect_addr.addr_len), SyscallSucceeds()); @@ -753,13 +724,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { // Bind the first FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -778,10 +747,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -813,13 +782,11 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -838,10 +805,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -877,20 +844,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), + bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - EXPECT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -912,10 +876,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -935,20 +899,17 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); EXPECT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), + bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - EXPECT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + EXPECT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -970,10 +931,10 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -1194,6 +1155,8 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) { } TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) { + // TODO(b/185517803): Fix for native test. 
+ SKIP_IF(!IsRunningOnGvisor()); auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -1292,16 +1255,15 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), SyscallSucceeds()); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(bind(sockets->second_fd(), AsSockAddr(&receiver_addr.addr), receiver_addr.addr_len), SyscallSucceeds()); // Get the port assigned. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); + ASSERT_THAT( + getsockname(sockets->second_fd(), AsSockAddr(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); // On the first iteration, save the port we are bound to. On the second // iteration, verify the port is the same as the one from the first @@ -1324,8 +1286,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { RandomizeBuffer(send_buf, sizeof(send_buf)); ASSERT_THAT( RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), + AsSockAddr(&send_addr.addr), send_addr.addr_len), SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet on both sockets. 
@@ -1367,16 +1328,15 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), SyscallSucceeds()); - ASSERT_THAT(bind(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(bind(sockets->second_fd(), AsSockAddr(&receiver_addr.addr), receiver_addr.addr_len), SyscallSucceeds()); // Get the port assigned. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(sockets->second_fd(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), - SyscallSucceeds()); + ASSERT_THAT( + getsockname(sockets->second_fd(), AsSockAddr(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); // On the first iteration, save the port we are bound to. On the second // iteration, verify the port is the same as the one from the first @@ -1403,8 +1363,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { RandomizeBuffer(send_buf, sizeof(send_buf)); ASSERT_THAT( RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), + AsSockAddr(&send_addr.addr), send_addr.addr_len), SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet on both sockets. 
@@ -1427,8 +1386,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { char send_buf[200]; ASSERT_THAT( RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), + AsSockAddr(&send_addr.addr), send_addr.addr_len), SyscallSucceedsWithValue(sizeof(send_buf))); char recv_buf[sizeof(send_buf)] = {}; @@ -1448,14 +1406,12 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { // Bind second socket (receiver) to the multicast address. auto receiver_addr = V4Multicast(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Update receiver_addr with the correct port number. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -1479,10 +1435,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. 
char recv_buf[sizeof(send_buf)] = {}; @@ -1500,14 +1456,12 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { // Bind second socket (receiver) to the multicast address. auto receiver_addr = V4Multicast(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Update receiver_addr with the correct port number. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -1523,10 +1477,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we don't receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -1543,13 +1497,11 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { // Bind second socket (receiver) to the ANY address. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -1557,12 +1509,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { // Bind the first socket (sender) to the multicast address. auto sender_addr = V4Multicast(); ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), + bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&sender_addr.addr), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&sender_addr.addr), &sender_addr_len), SyscallSucceeds()); EXPECT_EQ(sender_addr_len, sender_addr.addr_len); @@ -1573,10 +1523,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the packet. 
char recv_buf[sizeof(send_buf)] = {}; @@ -1594,13 +1544,11 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { // Bind second socket (receiver) to the broadcast address. auto receiver_addr = V4Broadcast(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -1611,19 +1559,18 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { SyscallSucceedsWithValue(0)); // Note: Binding to the loopback interface makes the broadcast go out of it. auto sender_bind_addr = V4Loopback(); - ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_bind_addr.addr), - sender_bind_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&sender_bind_addr.addr), + sender_bind_addr.addr_len), + SyscallSucceeds()); auto sendto_addr = V4Broadcast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. 
char recv_buf[sizeof(send_buf)] = {}; @@ -1641,13 +1588,11 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { // Bind second socket (receiver) to the ANY address. auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(socket2->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(socket2->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -1655,12 +1600,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { // Bind the first socket (sender) to the broadcast address. auto sender_addr = V4Broadcast(); ASSERT_THAT( - bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), + bind(socket1->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&sender_addr.addr), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&sender_addr.addr), &sender_addr_len), SyscallSucceeds()); EXPECT_EQ(sender_addr_len, sender_addr.addr_len); @@ -1671,10 +1614,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&sendto_addr.addr), 
sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; @@ -1688,7 +1631,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { // // FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. Enable // random and co-op save (below) once that is fixed. -TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { +TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution) { std::vector<std::unique_ptr<FileDescriptor>> sockets; sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); @@ -1698,12 +1641,10 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { // Bind the first socket to the loopback and take note of the selected port. auto addr = V4Loopback(); - ASSERT_THAT(bind(sockets[0]->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(sockets[0]->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(sockets[0]->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(sockets[0]->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -1719,8 +1660,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { ASSERT_THAT(setsockopt(last->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(last->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(last->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); // Send a new message to the SO_REUSEADDR group. 
We use a new socket each @@ -1730,8 +1670,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { char send_buf[kMessageSize]; RandomizeBuffer(send_buf, sizeof(send_buf)); EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceedsWithValue(sizeof(send_buf))); // Verify that the most recent socket got the message. We don't expect any @@ -1763,12 +1702,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) { // Bind the first socket to the loopback and take note of the selected port. auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -1776,8 +1713,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) { ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } @@ -1792,12 +1728,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) { // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -1805,8 +1739,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) { ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } @@ -1825,12 +1758,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) { // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -1838,16 +1769,14 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) { ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); // Bind socket3 to the same address as socket1, only with REUSEADDR. ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } @@ -1866,12 +1795,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) { // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -1879,16 +1806,14 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) { ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); // Bind socket3 to the same address as socket1, only with REUSEPORT. ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EADDRINUSE)); } @@ -1907,12 +1832,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) { // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -1920,8 +1843,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) { ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT. @@ -1931,8 +1853,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) { ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); } @@ -1951,12 +1872,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) { // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -1964,8 +1883,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) { ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT. @@ -1975,8 +1893,7 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) { ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); } @@ -1995,12 +1912,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) { // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -2013,16 +1928,14 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) { sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); // Bind socket3 to the same address as socket1, only with REUSEPORT. ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); } @@ -2041,12 +1954,10 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) { // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(socket1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(socket1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -2059,16 +1970,14 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) { sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); // Bind socket3 to the same address as socket1, only with REUSEADDR. ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(socket3->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); } @@ -2086,12 +1995,10 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(receiver1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(receiver1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(receiver1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -2103,8 +2010,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(bind(receiver2->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(receiver2->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); constexpr int kMessageSize = 10; @@ -2119,8 +2025,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); char send_buf[kMessageSize] = {}; EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceedsWithValue(sizeof(send_buf))); } @@ -2149,13 +2054,11 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { int level = SOL_IP; int type = IP_PKTINFO; - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); socklen_t sender_addr_len = sender_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&sender_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&sender_addr.addr), &sender_addr_len), SyscallSucceeds()); EXPECT_EQ(sender_addr_len, 
sender_addr.addr_len); @@ -2163,10 +2066,9 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { auto receiver_addr = V4Loopback(); reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port; - ASSERT_THAT( - connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(connect(sender->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Allow socket to receive control message. ASSERT_THAT( @@ -2230,29 +2132,25 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) { int level = SOL_IP; int type = IP_RECVORIGDSTADDR; - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Retrieve the port bound by the receiver. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - ASSERT_THAT( - connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(connect(sender->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Get address and port bound by the sender. 
sockaddr_storage sender_addr_storage; socklen_t sender_addr_len = sizeof(sender_addr_storage); - ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast<sockaddr*>(&sender_addr_storage), + ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&sender_addr_storage), &sender_addr_len), SyscallSucceeds()); ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in)); @@ -2407,9 +2305,7 @@ TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBuf) { SyscallSucceeds()); // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. - if (!IsRunningOnGvisor()) { - quarter_sz *= 2; - } + quarter_sz *= 2; ASSERT_EQ(quarter_sz, val); } @@ -2524,22 +2420,19 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) { // Bind the first FD to the loopback. This is an alternative to // IP_MULTICAST_IF for setting the default send interface. auto sender_addr = V4Loopback(); - ASSERT_THAT( - bind(sender_socket->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(sender_socket->get(), AsSockAddr(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT(bind(receiver_socket->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(bind(receiver_socket->get(), AsSockAddr(&receiver_addr.addr), receiver_addr.addr_len), SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; ASSERT_THAT(getsockname(receiver_socket->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - &receiver_addr_len), + AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -2565,8 +2458,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) { RandomizeBuffer(send_buf, sizeof(send_buf)); ASSERT_THAT( RetryEINTR(sendto)(sender_socket->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), + AsSockAddr(&send_addr.addr), send_addr.addr_len), SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index 940289d15..c6e775b2a 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -50,38 +50,35 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the first socket to the ANY address and let the system assign a port. auto rcv1_addr = V4Any(); - ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - rcv1_addr.addr_len), - SyscallSucceedsWithValue(0)); + ASSERT_THAT( + bind(rcvr1->get(), AsSockAddr(&rcv1_addr.addr), rcv1_addr.addr_len), + SyscallSucceedsWithValue(0)); // Retrieve port number from first socket so that it can be bound to the // second socket. 
socklen_t rcv_addr_sz = rcv1_addr.addr_len; ASSERT_THAT( - getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - &rcv_addr_sz), + getsockname(rcvr1->get(), AsSockAddr(&rcv1_addr.addr), &rcv_addr_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port; // Bind the second socket to the same address:port as the first. - ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - rcv_addr_sz), + ASSERT_THAT(bind(rcvr2->get(), AsSockAddr(&rcv1_addr.addr), rcv_addr_sz), SyscallSucceedsWithValue(0)); // Bind the non-receiving socket to an ephemeral port. auto norecv_addr = V4Any(); - ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), - norecv_addr.addr_len), - SyscallSucceedsWithValue(0)); + ASSERT_THAT( + bind(norcv->get(), AsSockAddr(&norecv_addr.addr), norecv_addr.addr_len), + SyscallSucceedsWithValue(0)); // Broadcast a test message. auto dst_addr = V4Broadcast(); reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port; constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT( - sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), - SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + AsSockAddr(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); // Verify that the receiving sockets received the test message. char buf[sizeof(kTestMsg)] = {}; @@ -130,15 +127,14 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the first socket the ANY address and let the system assign a port. 
auto rcv1_addr = V4Any(); - ASSERT_THAT(bind(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - rcv1_addr.addr_len), - SyscallSucceedsWithValue(0)); + ASSERT_THAT( + bind(rcvr1->get(), AsSockAddr(&rcv1_addr.addr), rcv1_addr.addr_len), + SyscallSucceedsWithValue(0)); // Retrieve port number from first socket so that it can be bound to the // second socket. socklen_t rcv_addr_sz = rcv1_addr.addr_len; ASSERT_THAT( - getsockname(rcvr1->get(), reinterpret_cast<sockaddr*>(&rcv1_addr.addr), - &rcv_addr_sz), + getsockname(rcvr1->get(), AsSockAddr(&rcv1_addr.addr), &rcv_addr_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(rcv_addr_sz, rcv1_addr.addr_len); auto port = reinterpret_cast<sockaddr_in*>(&rcv1_addr.addr)->sin_port; @@ -146,26 +142,25 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the second socket to the broadcast address. auto rcv2_addr = V4Broadcast(); reinterpret_cast<sockaddr_in*>(&rcv2_addr.addr)->sin_port = port; - ASSERT_THAT(bind(rcvr2->get(), reinterpret_cast<sockaddr*>(&rcv2_addr.addr), - rcv2_addr.addr_len), - SyscallSucceedsWithValue(0)); + ASSERT_THAT( + bind(rcvr2->get(), AsSockAddr(&rcv2_addr.addr), rcv2_addr.addr_len), + SyscallSucceedsWithValue(0)); // Bind the non-receiving socket to the unicast ethernet address. auto norecv_addr = rcv1_addr; reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr = eth_if_addr_.sin_addr; - ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr), - norecv_addr.addr_len), - SyscallSucceedsWithValue(0)); + ASSERT_THAT( + bind(norcv->get(), AsSockAddr(&norecv_addr.addr), norecv_addr.addr_len), + SyscallSucceedsWithValue(0)); // Broadcast a test message. 
auto dst_addr = V4Broadcast(); reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = port; constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT( - sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), - SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + AsSockAddr(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); // Verify that the receiving sockets received the test message. char buf[sizeof(kTestMsg)] = {}; @@ -199,12 +194,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the sender to the broadcast address. auto src_addr = V4Broadcast(); - ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr), - src_addr.addr_len), - SyscallSucceedsWithValue(0)); + ASSERT_THAT( + bind(sender->get(), AsSockAddr(&src_addr.addr), src_addr.addr_len), + SyscallSucceedsWithValue(0)); socklen_t src_sz = src_addr.addr_len; - ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz), + ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&src_addr.addr), &src_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(src_sz, src_addr.addr_len); @@ -213,10 +207,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port; constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT( - sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), - SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + AsSockAddr(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); // Verify that the message was received. 
char buf[sizeof(kTestMsg)] = {}; @@ -241,12 +234,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the sender to the ANY address. auto src_addr = V4Any(); - ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&src_addr.addr), - src_addr.addr_len), - SyscallSucceedsWithValue(0)); + ASSERT_THAT( + bind(sender->get(), AsSockAddr(&src_addr.addr), src_addr.addr_len), + SyscallSucceedsWithValue(0)); socklen_t src_sz = src_addr.addr_len; - ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast<sockaddr*>(&src_addr.addr), &src_sz), + ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&src_addr.addr), &src_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(src_sz, src_addr.addr_len); @@ -255,10 +247,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&dst_addr.addr)->sin_port = reinterpret_cast<sockaddr_in*>(&src_addr.addr)->sin_port; constexpr char kTestMsg[] = "hello, world"; - EXPECT_THAT( - sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&dst_addr.addr), dst_addr.addr_len), - SyscallSucceedsWithValue(sizeof(kTestMsg))); + EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, + AsSockAddr(&dst_addr.addr), dst_addr.addr_len), + SyscallSucceedsWithValue(sizeof(kTestMsg))); // Verify that the message was received. 
char buf[sizeof(kTestMsg)] = {}; @@ -280,7 +271,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) { constexpr char kTestMsg[] = "hello, world"; EXPECT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len), + AsSockAddr(&addr.addr), addr.addr_len), SyscallFailsWithErrno(EACCES)); } @@ -294,19 +285,17 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendUnicastOnUnbound) { addr.sin_family = AF_INET; addr.sin_addr.s_addr = htonl(INADDR_ANY); addr.sin_port = htons(0); - ASSERT_THAT(bind(rcvr->get(), reinterpret_cast<struct sockaddr*>(&addr), - sizeof(addr)), + ASSERT_THAT(bind(rcvr->get(), AsSockAddr(&addr), sizeof(addr)), SyscallSucceedsWithValue(0)); memset(&addr, 0, sizeof(addr)); socklen_t addr_sz = sizeof(addr); - ASSERT_THAT(getsockname(rcvr->get(), - reinterpret_cast<struct sockaddr*>(&addr), &addr_sz), + ASSERT_THAT(getsockname(rcvr->get(), AsSockAddr(&addr), &addr_sz), SyscallSucceedsWithValue(0)); // Send a test message to the receiver. 
constexpr char kTestMsg[] = "hello, world"; ASSERT_THAT(sendto(sender->get(), kTestMsg, sizeof(kTestMsg), 0, - reinterpret_cast<struct sockaddr*>(&addr), addr_sz), + AsSockAddr(&addr), addr_sz), SyscallSucceedsWithValue(sizeof(kTestMsg))); char buf[sizeof(kTestMsg)] = {}; ASSERT_THAT(recv(rcvr->get(), buf, sizeof(buf), 0), @@ -326,13 +315,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); - ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - bind_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket->get(), AsSockAddr(&bind_addr.addr), bind_addr.addr_len), + SyscallSucceeds()); socklen_t bind_addr_len = bind_addr.addr_len; ASSERT_THAT( - getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - &bind_addr_len), + getsockname(socket->get(), AsSockAddr(&bind_addr.addr), &bind_addr_len), SyscallSucceeds()); EXPECT_EQ(bind_addr_len, bind_addr.addr_len); @@ -342,10 +330,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. 
char recv_buf[sizeof(send_buf)] = {}; @@ -361,13 +349,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); - ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - bind_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket->get(), AsSockAddr(&bind_addr.addr), bind_addr.addr_len), + SyscallSucceeds()); socklen_t bind_addr_len = bind_addr.addr_len; ASSERT_THAT( - getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - &bind_addr_len), + getsockname(socket->get(), AsSockAddr(&bind_addr.addr), &bind_addr_len), SyscallSucceeds()); EXPECT_EQ(bind_addr_len, bind_addr.addr_len); @@ -384,10 +371,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. 
char recv_buf[sizeof(send_buf)] = {}; @@ -405,13 +392,12 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); - ASSERT_THAT(bind(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - bind_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(socket->get(), AsSockAddr(&bind_addr.addr), bind_addr.addr_len), + SyscallSucceeds()); socklen_t bind_addr_len = bind_addr.addr_len; ASSERT_THAT( - getsockname(socket->get(), reinterpret_cast<sockaddr*>(&bind_addr.addr), - &bind_addr_len), + getsockname(socket->get(), AsSockAddr(&bind_addr.addr), &bind_addr_len), SyscallSucceeds()); EXPECT_EQ(bind_addr_len, bind_addr.addr_len); @@ -433,10 +419,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&bind_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(socket->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -460,13 +446,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -477,10 +461,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -499,13 +483,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -523,10 +505,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -547,13 +529,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -576,10 +556,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -600,13 +580,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Bind the second FD to the v4 any address to ensure that we can receive the // multicast packet. 
auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -629,10 +607,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -661,13 +639,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); // Bind to ANY to receive multicast packets. 
- ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -696,10 +672,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); for (auto& receiver : receivers) { char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( @@ -727,13 +703,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, 
receiver_addr.addr_len); @@ -765,10 +739,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); for (auto& receiver : receivers) { char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( @@ -798,13 +772,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, ASSERT_THAT(setsockopt(receiver->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -840,10 +812,10 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = bound_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), 
send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); for (auto& receiver : receivers) { char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( @@ -863,13 +835,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -887,15 +857,13 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // receiver side). auto sendto_addr = V4Multicast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port; - ASSERT_THAT(RetryEINTR(connect)( - sender->get(), reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), + ASSERT_THAT(RetryEINTR(connect)(sender->get(), AsSockAddr(&sendto_addr.addr), + sendto_addr.addr_len), SyscallSucceeds()); auto sender_addr = V4EmptyAddress(); - ASSERT_THAT( - getsockname(sender->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - &sender_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&sender_addr.addr), + &sender_addr.addr_len), + SyscallSucceeds()); ASSERT_EQ(sizeof(struct sockaddr_in), sender_addr.addr_len); sockaddr_in* sender_addr_in = reinterpret_cast<sockaddr_in*>(&sender_addr.addr); @@ -910,8 +878,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto src_addr = V4EmptyAddress(); ASSERT_THAT( RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0, - 
reinterpret_cast<sockaddr*>(&src_addr.addr), - &src_addr.addr_len), + AsSockAddr(&src_addr.addr), &src_addr.addr_len), SyscallSucceedsWithValue(sizeof(recv_buf))); ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len); sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr); @@ -931,13 +898,11 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Create receiver, bind to ANY and join the multicast group. auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver_addr = V4Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -964,18 +929,17 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, auto sendto_addr = V4Multicast(); reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port = receiver_port; char send_buf[4] = {}; - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr.addr), - sendto_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&sendto_addr.addr), sendto_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Receive a multicast packet. 
char recv_buf[sizeof(send_buf)] = {}; auto src_addr = V4EmptyAddress(); ASSERT_THAT( RetryEINTR(recvfrom)(receiver->get(), recv_buf, sizeof(recv_buf), 0, - reinterpret_cast<sockaddr*>(&src_addr.addr), - &src_addr.addr_len), + AsSockAddr(&src_addr.addr), &src_addr.addr_len), SyscallSucceedsWithValue(sizeof(recv_buf))); ASSERT_EQ(sizeof(struct sockaddr_in), src_addr.addr_len); sockaddr_in* src_addr_in = reinterpret_cast<sockaddr_in*>(&src_addr.addr); @@ -1000,9 +964,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Create sender and bind to eth interface. auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&eth_if_addr_), - sizeof(eth_if_addr_)), - SyscallSucceeds()); + ASSERT_THAT( + bind(sender->get(), AsSockAddr(&eth_if_addr_), sizeof(eth_if_addr_)), + SyscallSucceeds()); // Run through all possible combinations of index and address for // IP_MULTICAST_IF that selects the loopback interface. diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc index bcbd2feac..7ca6d52e4 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc @@ -29,18 +29,15 @@ using IPv4UDPUnboundSocketNogotsanTest = SimpleSocketTest; // Check that connect returns EAGAIN when out of local ephemeral ports. // We disable S/R because this test creates a large number of sockets. -TEST_P(IPv4UDPUnboundSocketNogotsanTest, - UDPConnectPortExhaustion_NoRandomSave) { +TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPConnectPortExhaustion) { auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); constexpr int kClients = 65536; // Bind the first socket to the loopback and take note of the selected port. 
auto addr = V4Loopback(); - ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), + ASSERT_THAT(bind(receiver1->get(), AsSockAddr(&addr.addr), addr.addr_len), SyscallSucceeds()); socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(receiver1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + ASSERT_THAT(getsockname(receiver1->get(), AsSockAddr(&addr.addr), &addr_len), SyscallSucceeds()); EXPECT_EQ(addr_len, addr.addr_len); @@ -50,8 +47,7 @@ TEST_P(IPv4UDPUnboundSocketNogotsanTest, for (int i = 0; i < kClients; i++) { auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len); + int ret = connect(s->get(), AsSockAddr(&addr.addr), addr.addr_len); if (ret == 0) { sockets.push_back(std::move(s)); continue; @@ -63,7 +59,7 @@ TEST_P(IPv4UDPUnboundSocketNogotsanTest, // Check that bind returns EADDRINUSE when out of local ephemeral ports. // We disable S/R because this test creates a large number of sockets. 
-TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion_NoRandomSave) { +TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion) { auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); constexpr int kClients = 65536; auto addr = V4Loopback(); @@ -73,8 +69,7 @@ TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion_NoRandomSave) { for (int i = 0; i < kClients; i++) { auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - int ret = - bind(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len); + int ret = bind(s->get(), AsSockAddr(&addr.addr), addr.addr_len); if (ret == 0) { sockets.push_back(std::move(s)); continue; diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc index 9a9ddc297..020ce5d6e 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc @@ -56,10 +56,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.2", &(reinterpret_cast<sockaddr_in*>(&sender_addr.addr) ->sin_addr.s_addr))); - ASSERT_THAT( - bind(snd_sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(snd_sock->get(), AsSockAddr(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); // Send the packet to an unassigned address but an address that is in the // subnet associated with the loopback interface. 
@@ -69,23 +68,20 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.254", &(reinterpret_cast<sockaddr_in*>(&receiver_addr.addr) ->sin_addr.s_addr))); - ASSERT_THAT( - bind(rcv_sock->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(rcv_sock->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(rcv_sock->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(rcv_sock->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); ASSERT_EQ(receiver_addr_len, receiver_addr.addr_len); char send_buf[kSendBufSize]; RandomizeBuffer(send_buf, kSendBufSize); - ASSERT_THAT( - RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0, - reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceedsWithValue(kSendBufSize)); + ASSERT_THAT(RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0, + AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceedsWithValue(kSendBufSize)); // Check that we received the packet. 
char recv_buf[kSendBufSize] = {}; @@ -155,14 +151,12 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { << "socks[" << idx << "]"; if (bind_wildcard) { - ASSERT_THAT( - bind(sock->get(), reinterpret_cast<sockaddr*>(&any_address.addr), - any_address.addr_len), - SyscallSucceeds()) + ASSERT_THAT(bind(sock->get(), AsSockAddr(&any_address.addr), + any_address.addr_len), + SyscallSucceeds()) << "socks[" << idx << "]"; } else { - ASSERT_THAT(bind(sock->get(), - reinterpret_cast<sockaddr*>(&broadcast_address.addr), + ASSERT_THAT(bind(sock->get(), AsSockAddr(&broadcast_address.addr), broadcast_address.addr_len), SyscallSucceeds()) << "socks[" << idx << "]"; @@ -177,17 +171,16 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { // Broadcasts from each socket should be received by every socket (including // the sending socket). - for (long unsigned int w = 0; w < socks.size(); w++) { + for (size_t w = 0; w < socks.size(); w++) { auto& w_sock = socks[w]; - ASSERT_THAT( - RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0, - reinterpret_cast<sockaddr*>(&broadcast_address.addr), - broadcast_address.addr_len), - SyscallSucceedsWithValue(kSendBufSize)) + ASSERT_THAT(RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0, + AsSockAddr(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceedsWithValue(kSendBufSize)) << "write socks[" << w << "]"; // Check that we received the packet on all sockets. 
- for (long unsigned int r = 0; r < socks.size(); r++) { + for (size_t r = 0; r < socks.size(); r++) { auto& r_sock = socks[r]; struct pollfd poll_fd = {r_sock->get(), POLLIN, 0}; diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.cc b/test/syscalls/linux/socket_ipv6_udp_unbound.cc index 08526468e..a4e3371f4 100644 --- a/test/syscalls/linux/socket_ipv6_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv6_udp_unbound.cc @@ -47,29 +47,25 @@ TEST_P(IPv6UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) { int level = SOL_IPV6; int type = IPV6_RECVORIGDSTADDR; - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Retrieve the port bound by the receiver. socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); - ASSERT_THAT( - connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(connect(sender->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); // Get address and port bound by the sender. 
sockaddr_storage sender_addr_storage; socklen_t sender_addr_len = sizeof(sender_addr_storage); - ASSERT_THAT(getsockname(sender->get(), - reinterpret_cast<sockaddr*>(&sender_addr_storage), + ASSERT_THAT(getsockname(sender->get(), AsSockAddr(&sender_addr_storage), &sender_addr_len), SyscallSucceeds()); ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in6)); diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc index 7364a1ea5..8390f7c3b 100644 --- a/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc @@ -24,13 +24,11 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) { auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver_addr = V6Any(); - ASSERT_THAT( - bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), - receiver_addr.addr_len), - SyscallSucceeds()); + ASSERT_THAT(bind(receiver->get(), AsSockAddr(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); socklen_t receiver_addr_len = receiver_addr.addr_len; - ASSERT_THAT(getsockname(receiver->get(), - reinterpret_cast<sockaddr*>(&receiver_addr.addr), + ASSERT_THAT(getsockname(receiver->get(), AsSockAddr(&receiver_addr.addr), &receiver_addr_len), SyscallSucceeds()); EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); @@ -50,8 +48,7 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) { // Set the sender to the loopback interface. auto sender_addr = V6Loopback(); ASSERT_THAT( - bind(sender->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), + bind(sender->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), SyscallSucceeds()); // Send a multicast packet. 
@@ -60,10 +57,10 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) { reinterpret_cast<sockaddr_in6*>(&receiver_addr.addr)->sin6_port; char send_buf[200]; RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; @@ -77,10 +74,10 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) { &group_req, sizeof(group_req)), SyscallSucceeds()); RandomizeBuffer(send_buf, sizeof(send_buf)); - ASSERT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, - reinterpret_cast<sockaddr*>(&send_addr.addr), - send_addr.addr_len), - SyscallSucceedsWithValue(sizeof(send_buf))); + ASSERT_THAT( + RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0, + AsSockAddr(&send_addr.addr), send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); ASSERT_THAT(RetryEINTR(recv)(receiver->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc index 2ee218231..48aace78a 100644 --- a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc @@ -44,9 +44,9 @@ TEST_P(IPv6UDPUnboundSocketNetlinkTest, JoinSubnet) { reinterpret_cast<sockaddr_in6*>(&sender_addr.addr) ->sin6_addr.s6_addr)); auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - EXPECT_THAT(bind(sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), - sender_addr.addr_len), - 
SyscallFailsWithErrno(EADDRNOTAVAIL)); + EXPECT_THAT( + bind(sock->get(), AsSockAddr(&sender_addr.addr), sender_addr.addr_len), + SyscallFailsWithErrno(EADDRNOTAVAIL)); } } // namespace testing diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc index 538ee2268..0743322ac 100644 --- a/test/syscalls/linux/socket_stream_blocking.cc +++ b/test/syscalls/linux/socket_stream_blocking.cc @@ -68,7 +68,7 @@ TEST_P(BlockingStreamSocketPairTest, BlockPartialWriteClosed) { // Random save may interrupt the call to sendmsg() in SendLargeSendMsg(), // causing the write to be incomplete and the test to hang. -TEST_P(BlockingStreamSocketPairTest, SendMsgTooLarge_NoRandomSave) { +TEST_P(BlockingStreamSocketPairTest, SendMsgTooLarge) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int sndbuf; @@ -102,7 +102,7 @@ TEST_P(BlockingStreamSocketPairTest, RecvLessThanBuffer) { // Test that MSG_WAITALL causes recv to block until all requested data is // received. Random save can interrupt blocking and cause received data to be // returned, even if the amount received is less than the full requested amount. -TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll_NoRandomSave) { +TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[100]; diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index b2a96086c..9e3a129cf 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -82,8 +82,7 @@ Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain, RETURN_ERROR_IF_SYSCALL_FAIL(bound = socket(domain, type, protocol)); MaybeSave(); // Successful socket creation. 
RETURN_ERROR_IF_SYSCALL_FAIL( - bind(bound, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr))); + bind(bound, AsSockAddr(&bind_addr), sizeof(bind_addr))); MaybeSave(); // Successful bind. RETURN_ERROR_IF_SYSCALL_FAIL(listen(bound, /* backlog = */ 5)); MaybeSave(); // Successful listen. @@ -92,8 +91,7 @@ Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain, RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol)); MaybeSave(); // Successful socket creation. RETURN_ERROR_IF_SYSCALL_FAIL( - connect(connected, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr))); + connect(connected, AsSockAddr(&bind_addr), sizeof(bind_addr))); MaybeSave(); // Successful connect. int accepted; @@ -145,22 +143,22 @@ Creator<SocketPair> BidirectionalBindSocketPairCreator(bool abstract, RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol)); MaybeSave(); // Successful socket creation. RETURN_ERROR_IF_SYSCALL_FAIL( - bind(sock1, reinterpret_cast<struct sockaddr*>(&addr1), sizeof(addr1))); + bind(sock1, AsSockAddr(&addr1), sizeof(addr1))); MaybeSave(); // Successful bind. int sock2; RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol)); MaybeSave(); // Successful socket creation. RETURN_ERROR_IF_SYSCALL_FAIL( - bind(sock2, reinterpret_cast<struct sockaddr*>(&addr2), sizeof(addr2))); + bind(sock2, AsSockAddr(&addr2), sizeof(addr2))); MaybeSave(); // Successful bind. - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock1, reinterpret_cast<struct sockaddr*>(&addr2), sizeof(addr2))); + RETURN_ERROR_IF_SYSCALL_FAIL( + connect(sock1, AsSockAddr(&addr2), sizeof(addr2))); MaybeSave(); // Successful connect. - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock2, reinterpret_cast<struct sockaddr*>(&addr1), sizeof(addr1))); + RETURN_ERROR_IF_SYSCALL_FAIL( + connect(sock2, AsSockAddr(&addr1), sizeof(addr1))); MaybeSave(); // Successful connect. // Cleanup no longer needed resources. 
@@ -206,15 +204,15 @@ Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type, int sock1; RETURN_ERROR_IF_SYSCALL_FAIL(sock1 = socket(domain, type, protocol)); MaybeSave(); // Successful socket creation. - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock1, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); + RETURN_ERROR_IF_SYSCALL_FAIL( + connect(sock1, AsSockAddr(&addr), sizeof(addr))); MaybeSave(); // Successful connect. int sock2; RETURN_ERROR_IF_SYSCALL_FAIL(sock2 = socket(domain, type, protocol)); MaybeSave(); // Successful socket creation. - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock2, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); + RETURN_ERROR_IF_SYSCALL_FAIL( + connect(sock2, AsSockAddr(&addr), sizeof(addr))); MaybeSave(); // Successful connect. // Make and close another socketpair to ensure that the duped ends of the @@ -228,8 +226,8 @@ Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type, for (int i = 0; i < 2; i++) { int sock; RETURN_ERROR_IF_SYSCALL_FAIL(sock = socket(domain, type, protocol)); - RETURN_ERROR_IF_SYSCALL_FAIL(connect( - sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); + RETURN_ERROR_IF_SYSCALL_FAIL( + connect(sock, AsSockAddr(&addr), sizeof(addr))); RETURN_ERROR_IF_SYSCALL_FAIL(close(sock)); } @@ -308,11 +306,9 @@ template <typename T> PosixErrorOr<T> BindIP(int fd, bool dual_stack) { T addr = {}; LocalhostAddr(&addr, dual_stack); - RETURN_ERROR_IF_SYSCALL_FAIL( - bind(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))); + RETURN_ERROR_IF_SYSCALL_FAIL(bind(fd, AsSockAddr(&addr), sizeof(addr))); socklen_t addrlen = sizeof(addr); - RETURN_ERROR_IF_SYSCALL_FAIL( - getsockname(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen)); + RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(fd, AsSockAddr(&addr), &addrlen)); return addr; } @@ -329,9 +325,8 @@ CreateTCPConnectAcceptSocketPair(int bound, int connected, int type, bool dual_stack, T bind_addr) { int 
connect_result = 0; RETURN_ERROR_IF_SYSCALL_FAIL( - (connect_result = RetryEINTR(connect)( - connected, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr))) == -1 && + (connect_result = RetryEINTR(connect)(connected, AsSockAddr(&bind_addr), + sizeof(bind_addr))) == -1 && errno == EINPROGRESS ? 0 : connect_result); @@ -703,7 +698,7 @@ PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, } RETURN_ERROR_IF_SYSCALL_FAIL( - bind(fd.get(), reinterpret_cast<sockaddr*>(&storage), storage_size)); + bind(fd.get(), AsSockAddr(&storage), storage_size)); // If the user specified 0 as the port, we will return the port that the // kernel gave us, otherwise we will validate that this socket bound to the @@ -711,8 +706,7 @@ PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, sockaddr_storage bound_storage = {}; socklen_t bound_storage_size = sizeof(bound_storage); RETURN_ERROR_IF_SYSCALL_FAIL( - getsockname(fd.get(), reinterpret_cast<sockaddr*>(&bound_storage), - &bound_storage_size)); + getsockname(fd.get(), AsSockAddr(&bound_storage), &bound_storage_size)); int available_port = -1; if (bound_storage.ss_family == AF_INET) { diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index b3ab286b8..f7ba90130 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -520,6 +520,20 @@ uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr, uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload, ssize_t payload_len); +// Convenient functions for reinterpreting common types to sockaddr pointer. 
+inline sockaddr* AsSockAddr(sockaddr_storage* s) { + return reinterpret_cast<sockaddr*>(s); +} +inline sockaddr* AsSockAddr(sockaddr_in* s) { + return reinterpret_cast<sockaddr*>(s); +} +inline sockaddr* AsSockAddr(sockaddr_in6* s) { + return reinterpret_cast<sockaddr*>(s); +} +inline sockaddr* AsSockAddr(sockaddr_un* s) { + return reinterpret_cast<sockaddr*>(s); +} + namespace internal { PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family, SocketType type, bool reuse_addr); diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc index 884319e1d..9425e87a6 100644 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ b/test/syscalls/linux/socket_unix_non_stream.cc @@ -239,7 +239,7 @@ TEST_P(UnixNonStreamSocketPairTest, SendTimeout) { SyscallSucceeds()); // The buffer size should be big enough to avoid many iterations in the next - // loop. Otherwise, this will slow down cooperative_save tests. + // loop. Otherwise, this will slow down save tests. 
std::vector<char> buf(kPageSize); for (;;) { int ret; diff --git a/test/syscalls/linux/socket_unix_unbound_abstract.cc b/test/syscalls/linux/socket_unix_unbound_abstract.cc index 8b1762000..dd3d25450 100644 --- a/test/syscalls/linux/socket_unix_unbound_abstract.cc +++ b/test/syscalls/linux/socket_unix_unbound_abstract.cc @@ -72,6 +72,52 @@ TEST_P(UnboundAbstractUnixSocketPairTest, BindNothing) { SyscallSucceeds()); } +TEST_P(UnboundAbstractUnixSocketPairTest, ListenZeroBacklog) { + SKIP_IF((GetParam().type & SOCK_DGRAM) != 0); + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + struct sockaddr_un addr = {}; + addr.sun_family = AF_UNIX; + constexpr char kPath[] = "\x00/foo_bar"; + memcpy(addr.sun_path, kPath, sizeof(kPath)); + ASSERT_THAT(bind(sockets->first_fd(), + reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceeds()); + ASSERT_THAT(listen(sockets->first_fd(), 0 /* backlog */), SyscallSucceeds()); + ASSERT_THAT(connect(sockets->second_fd(), + reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceeds()); + auto sockets2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + { + // Set the FD to O_NONBLOCK. + int opts; + int orig_opts; + ASSERT_THAT(opts = fcntl(sockets2->first_fd(), F_GETFL), SyscallSucceeds()); + orig_opts = opts; + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(sockets2->first_fd(), F_SETFL, opts), SyscallSucceeds()); + + ASSERT_THAT( + connect(sockets2->first_fd(), reinterpret_cast<struct sockaddr*>(&addr), + sizeof(addr)), + SyscallFailsWithErrno(EAGAIN)); + } + { + // Set the FD to O_NONBLOCK. 
+ int opts; + int orig_opts; + ASSERT_THAT(opts = fcntl(sockets2->second_fd(), F_GETFL), + SyscallSucceeds()); + orig_opts = opts; + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(sockets2->second_fd(), F_SETFL, opts), SyscallSucceeds()); + + ASSERT_THAT( + connect(sockets2->second_fd(), + reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallFailsWithErrno(EAGAIN)); + } +} + TEST_P(UnboundAbstractUnixSocketPairTest, GetSockNameFullLength) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index e5730a606..c85f6da0b 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -883,7 +883,7 @@ TEST(SpliceTest, FromPipeToDevZero) { static volatile int signaled = 0; void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; } -TEST(SpliceTest, ToPipeWithSmallCapacityDoesNotSpin_NoRandomSave) { +TEST(SpliceTest, ToPipeWithSmallCapacityDoesNotSpin) { // Writes to a pipe that are less than PIPE_BUF must be atomic. This test // creates a pipe with only 128 bytes of capacity (< PIPE_BUF) and checks that // splicing to the pipe does not spin. See b/170743336. diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc index 4afed6d08..5a2841899 100644 --- a/test/syscalls/linux/sticky.cc +++ b/test/syscalls/linux/sticky.cc @@ -56,9 +56,7 @@ TEST(StickyTest, StickyBitPermDenied) { // thread won't be able to open some log files after the test ends. ScopedThread([&] { // Drop privileges. - if (HaveCapability(CAP_FOWNER).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, false)); - } + AutoCapability cap(CAP_FOWNER, false); // Change EUID and EGID. EXPECT_THAT( @@ -98,9 +96,7 @@ TEST(StickyTest, StickyBitSameUID) { // thread won't be able to open some log files after the test ends. ScopedThread([&] { // Drop privileges. 
- if (HaveCapability(CAP_FOWNER).ValueOrDie()) { - EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, false)); - } + AutoCapability cap(CAP_FOWNER, false); // Change EGID. EXPECT_THAT( diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index ea219a091..fa6849f11 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -100,8 +100,8 @@ TEST(SymlinkTest, CanCreateSymlinkDir) { TEST(SymlinkTest, CannotCreateSymlinkInReadOnlyDir) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const std::string olddir = NewTempAbsPath(); ASSERT_THAT(mkdir(olddir.c_str(), 0444), SyscallSucceeds()); @@ -248,10 +248,10 @@ TEST(SymlinkTest, PwriteToSymlink) { EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds()); } -TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) { +TEST(SymlinkTest, SymlinkAtDegradedPermissions) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); @@ -299,10 +299,10 @@ TEST(SymlinkTest, ReadlinkAtDirWithOpath) { EXPECT_EQ(0, strncmp("/dangling", buf.data(), linksize)); } -TEST(SymlinkTest, ReadlinkAtDegradedPermissions_NoRandomSave) { +TEST(SymlinkTest, ReadlinkAtDegradedPermissions) { // Drop capabilities that allow us to override file and directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const std::string oldpath = NewTempAbsPathInDir(dir.path()); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 7341cf1a6..5bfdecc79 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -139,20 +139,16 @@ void TcpSocketTest::SetUp() { socklen_t addrlen = sizeof(addr); // Bind to some port then start listening. - ASSERT_THAT( - bind(listener_, reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(listener_, AsSockAddr(&addr), addrlen), SyscallSucceeds()); ASSERT_THAT(listen(listener_, SOMAXCONN), SyscallSucceeds()); // Get the address we're listening on, then connect to it. We need to do this // because we're allowing the stack to pick a port for us. - ASSERT_THAT(getsockname(listener_, reinterpret_cast<struct sockaddr*>(&addr), - &addrlen), + ASSERT_THAT(getsockname(listener_, AsSockAddr(&addr), &addrlen), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)( - first_fd, reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(first_fd, AsSockAddr(&addr), addrlen), SyscallSucceeds()); // Get the initial send buffer size. @@ -229,10 +225,9 @@ TEST_P(TcpSocketTest, SenderAddressIgnored) { socklen_t addrlen = sizeof(addr); memset(&addr, 0, sizeof(addr)); - ASSERT_THAT( - RetryEINTR(recvfrom)(second_fd, buf, sizeof(buf), 0, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceedsWithValue(3)); + ASSERT_THAT(RetryEINTR(recvfrom)(second_fd, buf, sizeof(buf), 0, + AsSockAddr(&addr), &addrlen), + SyscallSucceedsWithValue(3)); // Check that addr remains zeroed-out. 
const char* ptr = reinterpret_cast<char*>(&addr); @@ -250,10 +245,9 @@ TEST_P(TcpSocketTest, SenderAddressIgnoredOnPeek) { socklen_t addrlen = sizeof(addr); memset(&addr, 0, sizeof(addr)); - ASSERT_THAT( - RetryEINTR(recvfrom)(second_fd, buf, sizeof(buf), MSG_PEEK, - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceedsWithValue(3)); + ASSERT_THAT(RetryEINTR(recvfrom)(second_fd, buf, sizeof(buf), MSG_PEEK, + AsSockAddr(&addr), &addrlen), + SyscallSucceedsWithValue(3)); // Check that addr remains zeroed-out. const char* ptr = reinterpret_cast<char*>(&addr); @@ -268,10 +262,9 @@ TEST_P(TcpSocketTest, SendtoAddressIgnored) { addr.ss_family = GetParam(); // FIXME(b/63803955) char data = '\0'; - EXPECT_THAT( - RetryEINTR(sendto)(first_fd, &data, sizeof(data), 0, - reinterpret_cast<sockaddr*>(&addr), sizeof(addr)), - SyscallSucceedsWithValue(1)); + EXPECT_THAT(RetryEINTR(sendto)(first_fd, &data, sizeof(data), 0, + AsSockAddr(&addr), sizeof(addr)), + SyscallSucceedsWithValue(1)); } TEST_P(TcpSocketTest, WritevZeroIovec) { @@ -331,7 +324,7 @@ TEST_P(TcpSocketTest, NonblockingLargeWrite) { // Test that a blocking write with a buffer that is larger than the send buffer // will block until the entire buffer is sent. -TEST_P(TcpSocketTest, BlockingLargeWrite_NoRandomSave) { +TEST_P(TcpSocketTest, BlockingLargeWrite) { // Allocate a buffer three times the size of the send buffer on the heap. We // do this as a vector to avoid allocating on the stack. int size = 3 * sendbuf_size_; @@ -415,7 +408,7 @@ TEST_P(TcpSocketTest, NonblockingLargeSend) { } // Same test as above, but calls send instead of write. -TEST_P(TcpSocketTest, BlockingLargeSend_NoRandomSave) { +TEST_P(TcpSocketTest, BlockingLargeSend) { // Allocate a buffer three times the size of the send buffer. We do this on // with a vector to avoid allocating on the stack. 
int size = 3 * sendbuf_size_; @@ -869,10 +862,9 @@ TEST_P(SimpleTcpSocketTest, SendtoWithAddressUnconnected) { sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); char data = '\0'; - EXPECT_THAT( - RetryEINTR(sendto)(fd, &data, sizeof(data), 0, - reinterpret_cast<sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(EPIPE)); + EXPECT_THAT(RetryEINTR(sendto)(fd, &data, sizeof(data), 0, AsSockAddr(&addr), + sizeof(addr)), + SyscallFailsWithErrno(EPIPE)); } TEST_P(SimpleTcpSocketTest, GetPeerNameUnconnected) { @@ -883,7 +875,7 @@ TEST_P(SimpleTcpSocketTest, GetPeerNameUnconnected) { sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen), + EXPECT_THAT(getpeername(fd, AsSockAddr(&addr), &addrlen), SyscallFailsWithErrno(ENOTCONN)); } @@ -974,24 +966,20 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRetry) { socklen_t addrlen = sizeof(addr); // Bind to some port but don't listen yet. - ASSERT_THAT( - bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(listener.get(), AsSockAddr(&addr), addrlen), + SyscallSucceeds()); // Get the address we're bound to, then connect to it. We need to do this // because we're allowing the stack to pick a port for us. - ASSERT_THAT(getsockname(listener.get(), - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + ASSERT_THAT(getsockname(listener.get(), AsSockAddr(&addr), &addrlen), SyscallSucceeds()); FileDescriptor connector = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); // Verify that connect fails. 
- ASSERT_THAT( - RetryEINTR(connect)(connector.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallFailsWithErrno(ECONNREFUSED)); + ASSERT_THAT(RetryEINTR(connect)(connector.get(), AsSockAddr(&addr), addrlen), + SyscallFailsWithErrno(ECONNREFUSED)); // Now start listening ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds()); @@ -1000,17 +988,14 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRetry) { // failed first connect should succeed. if (IsRunningOnGvisor()) { ASSERT_THAT( - RetryEINTR(connect)(connector.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), + RetryEINTR(connect)(connector.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(ECONNABORTED)); return; } // Verify that connect now succeeds. - ASSERT_THAT( - RetryEINTR(connect)(connector.get(), - reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(connector.get(), AsSockAddr(&addr), addrlen), + SyscallSucceeds()); // Accept the connection. const FileDescriptor accepted = @@ -1030,13 +1015,11 @@ PosixErrorOr<FileDescriptor> nonBlockingConnectNoListener(const int family, int b_sock; RETURN_ERROR_IF_SYSCALL_FAIL(b_sock = socket(family, sock_type, IPPROTO_TCP)); FileDescriptor b(b_sock); - EXPECT_THAT(bind(b.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + EXPECT_THAT(bind(b.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); // Get the address bound by the listening socket. - EXPECT_THAT( - getsockname(b.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(b.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); // Now create another socket and issue a connect on this one. This connect // should fail as there is no listener. 
@@ -1046,8 +1029,7 @@ PosixErrorOr<FileDescriptor> nonBlockingConnectNoListener(const int family, // Now connect to the bound address and this should fail as nothing // is listening on the bound address. - EXPECT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + EXPECT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(EINPROGRESS)); // Wait for the connect to fail. @@ -1078,8 +1060,7 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) { opts &= ~O_NONBLOCK; EXPECT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds()); // Try connecting again. - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(ECONNABORTED)); } @@ -1094,8 +1075,7 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListenerRead) { unsigned char c; ASSERT_THAT(read(s.get(), &c, 1), SyscallFailsWithErrno(ECONNREFUSED)); ASSERT_THAT(read(s.get(), &c, 1), SyscallSucceedsWithValue(0)); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(ECONNABORTED)); } @@ -1111,12 +1091,11 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListenerPeek) { ASSERT_THAT(recv(s.get(), &c, 1, MSG_PEEK), SyscallFailsWithErrno(ECONNREFUSED)); ASSERT_THAT(recv(s.get(), &c, 1, MSG_PEEK), SyscallSucceedsWithValue(0)); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(ECONNABORTED)); } -TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) { +TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv) { // Initialize address to the loopback one. 
sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); @@ -1125,15 +1104,11 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) { const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT( - (bind)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT((bind)(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); // Get the bound port. - ASSERT_THAT( - getsockname(s.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); constexpr int kBufSz = 1 << 20; // 1 MiB @@ -1168,7 +1143,7 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) { EXPECT_EQ(read_bytes, kBufSz); } -TEST_P(SimpleTcpSocketTest, SelfConnectSend_NoRandomSave) { +TEST_P(SimpleTcpSocketTest, SelfConnectSend) { // Initialize address to the loopback one. sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); @@ -1182,17 +1157,20 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSend_NoRandomSave) { setsockopt(s.get(), SOL_TCP, TCP_MAXSEG, &max_seg, sizeof(max_seg)), SyscallSucceeds()); - ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); // Get the bound port. 
- ASSERT_THAT( - getsockname(s.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(getsockname(s.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); - std::vector<char> writebuf(512 << 10); // 512 KiB. + // Ensure the write buffer is large enough not to block on a single write. + size_t write_size = 512 << 10; // 512 KiB. + EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_SNDBUF, &write_size, + sizeof(write_size)), + SyscallSucceedsWithValue(0)); + + std::vector<char> writebuf(write_size); // Try to send the whole thing. int n; @@ -1213,9 +1191,8 @@ void NonBlockingConnect(int family, int16_t pollMask) { socklen_t addrlen = sizeof(addr); // Bind to some port then start listening. - ASSERT_THAT( - bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(listener.get(), AsSockAddr(&addr), addrlen), + SyscallSucceeds()); ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds()); @@ -1228,12 +1205,10 @@ void NonBlockingConnect(int family, int16_t pollMask) { opts |= O_NONBLOCK; ASSERT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds()); - ASSERT_THAT(getsockname(listener.get(), - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + ASSERT_THAT(getsockname(listener.get(), AsSockAddr(&addr), &addrlen), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(EINPROGRESS)); int t; @@ -1276,21 +1251,18 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRemoteClose) { socklen_t addrlen = sizeof(addr); // Bind to some port then start listening. 
- ASSERT_THAT( - bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(listener.get(), AsSockAddr(&addr), addrlen), + SyscallSucceeds()); ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds()); FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - ASSERT_THAT(getsockname(listener.get(), - reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + ASSERT_THAT(getsockname(listener.get(), AsSockAddr(&addr), &addrlen), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(EINPROGRESS)); int t; @@ -1305,12 +1277,10 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRemoteClose) { EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), SyscallSucceedsWithValue(1)); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(EISCONN)); } @@ -1325,8 +1295,7 @@ TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) { ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); socklen_t addrlen = sizeof(addr); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(ECONNREFUSED)); // Avoiding triggering save in destructor of s. 
@@ -1346,17 +1315,14 @@ TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) { ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); socklen_t bound_addrlen = sizeof(bound_addr); - ASSERT_THAT( - bind(bound_s.get(), reinterpret_cast<struct sockaddr*>(&bound_addr), - bound_addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(bound_s.get(), AsSockAddr(&bound_addr), bound_addrlen), + SyscallSucceeds()); // Get the addresses the socket is bound to because the port is chosen by the // stack. - ASSERT_THAT(getsockname(bound_s.get(), - reinterpret_cast<struct sockaddr*>(&bound_addr), - &bound_addrlen), - SyscallSucceeds()); + ASSERT_THAT( + getsockname(bound_s.get(), AsSockAddr(&bound_addr), &bound_addrlen), + SyscallSucceeds()); // Create, initialize, and bind the socket that is used to test connecting to // the non-listening port. @@ -1367,16 +1333,13 @@ TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) { ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); socklen_t client_addrlen = sizeof(client_addr); + ASSERT_THAT(bind(client_s.get(), AsSockAddr(&client_addr), client_addrlen), + SyscallSucceeds()); + ASSERT_THAT( - bind(client_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr), - client_addrlen), + getsockname(client_s.get(), AsSockAddr(&client_addr), &client_addrlen), SyscallSucceeds()); - ASSERT_THAT(getsockname(client_s.get(), - reinterpret_cast<struct sockaddr*>(&client_addr), - &client_addrlen), - SyscallSucceeds()); - // Now the test: connect to the bound but not listening socket with the // client socket. The bound socket should return a RST and cause the client // socket to return an error and clean itself up immediately. @@ -1392,10 +1355,8 @@ TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) { // Test binding to the address from the client socket. This should be okay // if it was dropped correctly. 
- ASSERT_THAT( - bind(new_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr), - client_addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(new_s.get(), AsSockAddr(&client_addr), client_addrlen), + SyscallSucceeds()); // Attempt #2, with the new socket and reused addr our connect should fail in // the same way as before, not with an EADDRINUSE. @@ -1428,8 +1389,7 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) { ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); socklen_t addrlen = sizeof(addr); - ASSERT_THAT(RetryEINTR(connect)( - s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), SyscallFailsWithErrno(EINPROGRESS)); // We don't need to specify any events to get POLLHUP or POLLERR as these @@ -1720,8 +1680,7 @@ TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) { ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); socklen_t addrlen = sizeof(addr); - RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), - addrlen); + RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen); int buf_sz = 1 << 18; EXPECT_THAT( setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), @@ -2034,8 +1993,7 @@ TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) { socklen_t addrlen = sizeof(addr); // Bind to some port then start listening. 
- ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + ASSERT_THAT(bind(s.get(), AsSockAddr(&addr), addrlen), SyscallSucceeds()); ASSERT_THAT(listen(s.get(), SOMAXCONN), SyscallSucceeds()); @@ -2062,10 +2020,8 @@ TEST_P(SimpleTcpSocketTest, ConnectUnspecifiedAddress) { auto do_connect = [&addr, addrlen]() { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(addr.ss_family, SOCK_STREAM, IPPROTO_TCP)); - ASSERT_THAT( - RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), - addrlen), - SyscallFailsWithErrno(ECONNREFUSED)); + ASSERT_THAT(RetryEINTR(connect)(s.get(), AsSockAddr(&addr), addrlen), + SyscallFailsWithErrno(ECONNREFUSED)); }; do_connect(); // Test the v4 mapped address as well. diff --git a/test/syscalls/linux/timerfd.cc b/test/syscalls/linux/timerfd.cc index c4f8fdd7a..072c92797 100644 --- a/test/syscalls/linux/timerfd.cc +++ b/test/syscalls/linux/timerfd.cc @@ -114,7 +114,7 @@ TEST_P(TimerfdTest, BlockingRead) { EXPECT_GE((end_time - start_time) + TimerSlack(), kDelay); } -TEST_P(TimerfdTest, NonblockingRead_NoRandomSave) { +TEST_P(TimerfdTest, NonblockingRead) { constexpr absl::Duration kDelay = absl::Seconds(5); auto const tfd = diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc index 17832c47d..0f08d9996 100644 --- a/test/syscalls/linux/truncate.cc +++ b/test/syscalls/linux/truncate.cc @@ -181,7 +181,7 @@ TEST(TruncateTest, FtruncateDir) { TEST(TruncateTest, TruncateNonWriteable) { // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to // always override write permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */)); EXPECT_THAT(truncate(temp_file.path().c_str(), 0), @@ -208,9 +208,9 @@ TEST(TruncateTest, FtruncateWithOpath) { // ftruncate(2) should succeed as long as the file descriptor is writeable, // regardless of whether the file permissions allow writing. -TEST(TruncateTest, FtruncateWithoutWritePermission_NoRandomSave) { +TEST(TruncateTest, FtruncateWithoutWritePermission) { // Drop capabilities that allow us to override file permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + AutoCapability cap(CAP_DAC_OVERRIDE, false); // The only time we can open a file with flags forbidden by its permissions // is when we are creating the file. We cannot re-open with the same flags, @@ -230,7 +230,7 @@ TEST(TruncateTest, TruncateNonExist) { EXPECT_THAT(truncate("/foo/bar", 0), SyscallFailsWithErrno(ENOENT)); } -TEST(TruncateTest, FtruncateVirtualTmp_NoRandomSave) { +TEST(TruncateTest, FtruncateVirtualTmp) { auto temp_file = NewTempAbsPathInDir("/dev/shm"); const DisableSave ds; // Incompatible permissions. const FileDescriptor fd = diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc index 13ed0d68a..279fe342c 100644 --- a/test/syscalls/linux/tuntap.cc +++ b/test/syscalls/linux/tuntap.cc @@ -170,10 +170,10 @@ TEST(TuntapStaticTest, NetTunExists) { class TuntapTest : public ::testing::Test { protected: void SetUp() override { - have_net_admin_cap_ = + const bool have_net_admin_cap = ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)); - if (have_net_admin_cap_ && !IsRunningOnGvisor()) { + if (have_net_admin_cap && !IsRunningOnGvisor()) { // gVisor always creates enabled/up'd interfaces, while Linux does not (as // observed in b/110961832). 
Some of the tests require the Linux stack to // notify the socket of any link-address-resolution failures. Those @@ -183,21 +183,12 @@ class TuntapTest : public ::testing::Test { ASSERT_NO_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP)); } } - - void TearDown() override { - if (have_net_admin_cap_) { - // Bring back capability if we had dropped it in test case. - ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); - } - } - - bool have_net_admin_cap_; }; TEST_F(TuntapTest, CreateInterfaceNoCap) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); + AutoCapability cap(CAP_NET_ADMIN, false); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); @@ -349,9 +340,8 @@ TEST_F(TuntapTest, PingKernel) { }; while (1) { inpkt r = {}; - int nread = read(fd.get(), &r, sizeof(r)); - EXPECT_THAT(nread, SyscallSucceeds()); - long unsigned int n = static_cast<long unsigned int>(nread); + size_t n; + EXPECT_THAT(n = read(fd.get(), &r, sizeof(r)), SyscallSucceeds()); if (n < sizeof(pihdr)) { std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol @@ -397,8 +387,7 @@ TEST_F(TuntapTest, SendUdpTriggersArpResolution) { .sin_port = htons(42), .sin_addr = {.s_addr = kTapPeerIPAddr}, }; - ASSERT_THAT(sendto(sock, "hello", 5, 0, reinterpret_cast<sockaddr*>(&remote), - sizeof(remote)), + ASSERT_THAT(sendto(sock, "hello", 5, 0, AsSockAddr(&remote), sizeof(remote)), SyscallSucceeds()); struct inpkt { @@ -409,9 +398,8 @@ TEST_F(TuntapTest, SendUdpTriggersArpResolution) { }; while (1) { inpkt r = {}; - int nread = read(fd.get(), &r, sizeof(r)); - EXPECT_THAT(nread, SyscallSucceeds()); - long unsigned int n = static_cast<long unsigned int>(nread); + size_t n; + EXPECT_THAT(n = read(fd.get(), &r, sizeof(r)), SyscallSucceeds()); if (n < sizeof(pihdr)) { std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol @@ -498,7 +486,7 @@ TEST_F(TuntapTest, WriteHangBug155928773) { .sin_addr = {.s_addr 
= kTapIPAddr}, }; // Return values do not matter in this test. - connect(sock, reinterpret_cast<struct sockaddr*>(&remote), sizeof(remote)); + connect(sock, AsSockAddr(&remote), sizeof(remote)); write(sock, "hello", 5); } diff --git a/test/syscalls/linux/udp_bind.cc b/test/syscalls/linux/udp_bind.cc index 6d92bdbeb..f68d78aa2 100644 --- a/test/syscalls/linux/udp_bind.cc +++ b/test/syscalls/linux/udp_bind.cc @@ -83,27 +83,24 @@ TEST_P(SendtoTest, Sendto) { ASSERT_NO_ERRNO_AND_VALUE(Socket(param.recv_domain, SOCK_DGRAM, 0)); if (param.send_addr_len > 0) { - ASSERT_THAT(bind(s1.get(), reinterpret_cast<sockaddr*>(&param.send_addr), - param.send_addr_len), - SyscallSucceeds()); + ASSERT_THAT( + bind(s1.get(), AsSockAddr(&param.send_addr), param.send_addr_len), + SyscallSucceeds()); } if (param.connect_addr_len > 0) { - ASSERT_THAT( - connect(s1.get(), reinterpret_cast<sockaddr*>(&param.connect_addr), - param.connect_addr_len), - SyscallSucceeds()); + ASSERT_THAT(connect(s1.get(), AsSockAddr(&param.connect_addr), + param.connect_addr_len), + SyscallSucceeds()); } - ASSERT_THAT(bind(s2.get(), reinterpret_cast<sockaddr*>(&param.recv_addr), - param.recv_addr_len), + ASSERT_THAT(bind(s2.get(), AsSockAddr(&param.recv_addr), param.recv_addr_len), SyscallSucceeds()); struct sockaddr_storage real_recv_addr = {}; socklen_t real_recv_addr_len = param.recv_addr_len; ASSERT_THAT( - getsockname(s2.get(), reinterpret_cast<sockaddr*>(&real_recv_addr), - &real_recv_addr_len), + getsockname(s2.get(), AsSockAddr(&real_recv_addr), &real_recv_addr_len), SyscallSucceeds()); ASSERT_EQ(real_recv_addr_len, param.recv_addr_len); @@ -116,23 +113,22 @@ TEST_P(SendtoTest, Sendto) { char buf[20] = {}; if (!param.sendto_errnos.empty()) { - ASSERT_THAT(RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr), - param.sendto_addr_len), - SyscallFailsWithErrno(ElementOf(param.sendto_errnos))); + ASSERT_THAT( + RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0, + AsSockAddr(&sendto_addr),
param.sendto_addr_len), + SyscallFailsWithErrno(ElementOf(param.sendto_errnos))); return; } - ASSERT_THAT(RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&sendto_addr), - param.sendto_addr_len), - SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT( + RetryEINTR(sendto)(s1.get(), buf, sizeof(buf), 0, + AsSockAddr(&sendto_addr), param.sendto_addr_len), + SyscallSucceedsWithValue(sizeof(buf))); struct sockaddr_storage got_addr = {}; socklen_t got_addr_len = sizeof(sockaddr_storage); ASSERT_THAT(RetryEINTR(recvfrom)(s2.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&got_addr), - &got_addr_len), + AsSockAddr(&got_addr), &got_addr_len), SyscallSucceedsWithValue(sizeof(buf))); ASSERT_GT(got_addr_len, sizeof(sockaddr_in_common)); @@ -140,8 +136,7 @@ TEST_P(SendtoTest, Sendto) { struct sockaddr_storage sender_addr = {}; socklen_t sender_addr_len = sizeof(sockaddr_storage); - ASSERT_THAT(getsockname(s1.get(), reinterpret_cast<sockaddr*>(&sender_addr), - &sender_addr_len), + ASSERT_THAT(getsockname(s1.get(), AsSockAddr(&sender_addr), &sender_addr_len), SyscallSucceeds()); ASSERT_GT(sender_addr_len, sizeof(sockaddr_in_common)); diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 16eeeb5c6..29e174f71 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -138,7 +138,7 @@ void UdpSocketTest::SetUp() { bind_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); - bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + bind_addr_ = AsSockAddr(&bind_addr_storage_); sock_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); @@ -153,15 +153,13 @@ int UdpSocketTest::GetFamily() { PosixError UdpSocketTest::BindLoopback() { bind_addr_storage_ = InetLoopbackAddr(); - struct sockaddr* bind_addr_ = - reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); 
+ struct sockaddr* bind_addr_ = AsSockAddr(&bind_addr_storage_); return BindSocket(bind_.get(), bind_addr_); } PosixError UdpSocketTest::BindAny() { bind_addr_storage_ = InetAnyAddr(); - struct sockaddr* bind_addr_ = - reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + struct sockaddr* bind_addr_ = AsSockAddr(&bind_addr_storage_); return BindSocket(bind_.get(), bind_addr_); } @@ -195,7 +193,7 @@ socklen_t UdpSocketTest::GetAddrLength() { sockaddr_storage UdpSocketTest::InetAnyAddr() { struct sockaddr_storage addr; memset(&addr, 0, sizeof(addr)); - reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + AsSockAddr(&addr)->sa_family = GetFamily(); if (GetFamily() == AF_INET) { auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); @@ -213,7 +211,7 @@ sockaddr_storage UdpSocketTest::InetAnyAddr() { sockaddr_storage UdpSocketTest::InetLoopbackAddr() { struct sockaddr_storage addr; memset(&addr, 0, sizeof(addr)); - reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + AsSockAddr(&addr)->sa_family = GetFamily(); if (GetFamily() == AF_INET) { auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); @@ -229,7 +227,7 @@ sockaddr_storage UdpSocketTest::InetLoopbackAddr() { void UdpSocketTest::Disconnect(int sockfd) { sockaddr_storage addr_storage = InetAnyAddr(); - sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + sockaddr* addr = AsSockAddr(&addr_storage); socklen_t addrlen = sizeof(addr_storage); addr->sa_family = AF_UNSPEC; @@ -265,19 +263,16 @@ TEST_P(UdpSocketTest, Getsockname) { // Check that we're not bound. 
struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(bind_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); struct sockaddr_storage any = InetAnyAddr(); - EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), - 0); + EXPECT_EQ(memcmp(&addr, AsSockAddr(&any), addrlen_), 0); ASSERT_NO_ERRNO(BindLoopback()); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(bind_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); @@ -289,17 +284,15 @@ TEST_P(UdpSocketTest, Getpeername) { // Check that we're not connected. struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); // Connect, then check that we get the right address. ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); } @@ -322,9 +315,8 @@ TEST_P(UdpSocketTest, SendNotConnected) { // Check that we're bound now. 
struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_NE(*Port(&addr), 0); } @@ -338,9 +330,8 @@ TEST_P(UdpSocketTest, ConnectBinds) { // Check that we're bound now. struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_NE(*Port(&addr), 0); } @@ -361,9 +352,8 @@ TEST_P(UdpSocketTest, Bind) { // Check that we're still bound to the original address. struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(bind_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); } @@ -383,7 +373,7 @@ TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) { // same time. struct sockaddr_storage addr_storage = InetLoopbackAddr(); socklen_t addrlen = sizeof(addr_storage); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds()); @@ -417,7 +407,7 @@ TEST_P(UdpSocketTest, ConnectSimultaneousWriteToInvalidPort) { // same time. 
struct sockaddr_storage addr_storage = InetLoopbackAddr(); socklen_t addrlen = sizeof(addr_storage); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds()); @@ -465,18 +455,17 @@ TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); // Send from sock to bind_. char buf[512]; RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT( + sendto(bind_.get(), buf, sizeof(buf), 0, AsSockAddr(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. char received[sizeof(buf)]; @@ -499,21 +488,18 @@ TEST_P(UdpSocketTest, Connect) { // Check that we're connected to the right peer. struct sockaddr_storage peer; socklen_t peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&peer), &peerlen), + SyscallSucceeds()); EXPECT_EQ(peerlen, addrlen_); EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); // Try to bind after connect. 
struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(bind(sock_.get(), AsSockAddr(&any), addrlen_), + SyscallFailsWithErrno(EINVAL)); struct sockaddr_storage bind2_storage = InetLoopbackAddr(); - struct sockaddr* bind2_addr = - reinterpret_cast<struct sockaddr*>(&bind2_storage); + struct sockaddr* bind2_addr = AsSockAddr(&bind2_storage); FileDescriptor bind2 = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); @@ -523,9 +509,8 @@ TEST_P(UdpSocketTest, Connect) { // Check that peer name changed. peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&peer), &peerlen), + SyscallSucceeds()); EXPECT_EQ(peerlen, addrlen_); EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); } @@ -535,15 +520,13 @@ TEST_P(UdpSocketTest, ConnectAnyZero) { SKIP_IF(IsRunningOnGvisor()); struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallSucceeds()); + EXPECT_THAT(connect(sock_.get(), AsSockAddr(&any), addrlen_), + SyscallSucceeds()); struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); } TEST_P(UdpSocketTest, ConnectAnyWithPort) { @@ -552,24 +535,21 @@ TEST_P(UdpSocketTest, ConnectAnyWithPort) { struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen), 
+ SyscallSucceeds()); } TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { // TODO(138658473): Enable when we can connect to port 0 with gVisor. SKIP_IF(IsRunningOnGvisor()); struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallSucceeds()); + EXPECT_THAT(connect(sock_.get(), AsSockAddr(&any), addrlen_), + SyscallSucceeds()); struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); Disconnect(sock_.get()); } @@ -580,9 +560,8 @@ TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); @@ -595,7 +574,7 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) { // Bind to the next port above bind_. struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); @@ -604,15 +583,14 @@ TEST_P(UdpSocketTest, DisconnectAfterBind) { struct sockaddr_storage unspec = {}; unspec.ss_family = AF_UNSPEC; - EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), - sizeof(unspec.ss_family)), - SyscallSucceeds()); + EXPECT_THAT( + connect(sock_.get(), AsSockAddr(&unspec), sizeof(unspec.ss_family)), + SyscallSucceeds()); // Check that we're still bound. 
socklen_t addrlen = sizeof(unspec); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&unspec), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); @@ -626,7 +604,7 @@ TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { ASSERT_NO_ERRNO(BindAny()); struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); socklen_t addrlen = sizeof(addr); @@ -653,7 +631,7 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { ASSERT_NO_ERRNO(BindLoopback()); struct sockaddr_storage any_storage = InetAnyAddr(); - struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + struct sockaddr* any = AsSockAddr(&any_storage); SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); @@ -666,24 +644,22 @@ TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { // Check that we're still bound. 
struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_EQ(memcmp(&addr, any, addrlen), 0); addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); } TEST_P(UdpSocketTest, Disconnect) { ASSERT_NO_ERRNO(BindLoopback()); struct sockaddr_storage any_storage = InetAnyAddr(); - struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + struct sockaddr* any = AsSockAddr(&any_storage); SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); @@ -694,29 +670,25 @@ TEST_P(UdpSocketTest, Disconnect) { // Check that we're connected to the right peer. struct sockaddr_storage peer; socklen_t peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&peer), &peerlen), + SyscallSucceeds()); EXPECT_EQ(peerlen, addrlen_); EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); // Try to disconnect. struct sockaddr_storage addr = {}; addr.ss_family = AF_UNSPEC; - EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), - sizeof(addr.ss_family)), + EXPECT_THAT(connect(sock_.get(), AsSockAddr(&addr), sizeof(addr.ss_family)), SyscallSucceeds()); peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(getpeername(sock_.get(), AsSockAddr(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); // Check that we're still bound. 
socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_EQ(*Port(&addr), *Port(&any_storage)); } @@ -733,7 +705,7 @@ TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { ASSERT_NO_ERRNO(BindLoopback()); struct sockaddr_storage addr_storage = InetAnyAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); @@ -881,7 +853,7 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { ASSERT_NO_ERRNO(BindLoopback()); // Connect to loopback:bind_addr_+1. struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); @@ -910,7 +882,7 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { // Connect to loopback:bind_addr_port+1. struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); @@ -961,7 +933,7 @@ TEST_P(UdpSocketTest, SendAndReceiveConnected) { // Connect to loopback:bind_addr_port+1. 
struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); @@ -987,13 +959,13 @@ TEST_P(UdpSocketTest, ReceiveFromNotConnected) { // Connect to loopback:bind_addr_port+1. struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); // Bind sock to loopback:bind_addr_port+2. struct sockaddr_storage addr2_storage = InetLoopbackAddr(); - struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + struct sockaddr* addr2 = AsSockAddr(&addr2_storage); SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); @@ -1013,7 +985,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) { // Bind sock to loopback:bind_addr_port+2. struct sockaddr_storage addr2_storage = InetLoopbackAddr(); - struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + struct sockaddr* addr2 = AsSockAddr(&addr2_storage); SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); @@ -1026,7 +998,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) { // Connect to loopback:bind_addr_port+1. 
struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); @@ -1050,7 +1022,7 @@ TEST_P(UdpSocketTest, ReceiveFrom) { // Connect to loopback:bind_addr_port+1. struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); @@ -1069,7 +1041,7 @@ TEST_P(UdpSocketTest, ReceiveFrom) { struct sockaddr_storage addr2; socklen_t addr2len = sizeof(addr2); EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, - reinterpret_cast<sockaddr*>(&addr2), &addr2len), + AsSockAddr(&addr2), &addr2len), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); EXPECT_EQ(addr2len, addrlen_); @@ -1093,7 +1065,7 @@ TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { // Connect to loopback:bind_addr_port+1. struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); @@ -1149,7 +1121,7 @@ TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { // Connect to loopback:bind_addr_port+1. 
struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + struct sockaddr* addr = AsSockAddr(&addr_storage); SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); @@ -1557,10 +1529,6 @@ TEST_P(UdpSocketTest, ErrorQueue) { #endif // __linux__ TEST_P(UdpSocketTest, SoTimestampOffByDefault) { - // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - int v = -1; socklen_t optlen = sizeof(v); ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), @@ -1570,10 +1538,6 @@ TEST_P(UdpSocketTest, SoTimestampOffByDefault) { } TEST_P(UdpSocketTest, SoTimestamp) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - ASSERT_NO_ERRNO(BindLoopback()); ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); @@ -1583,8 +1547,8 @@ TEST_P(UdpSocketTest, SoTimestamp) { char buf[3]; // Send zero length packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); struct pollfd pfd = {bind_.get(), POLLIN, 0}; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), @@ -1614,9 +1578,13 @@ TEST_P(UdpSocketTest, SoTimestamp) { ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - // There should be nothing to get via ioctl. - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), - SyscallFailsWithErrno(ENOENT)); + // TODO(gvisor.dev/issue/1202): ioctl(SIOCGSTAMP) is not supported by + // hostinet. + if (!IsRunningWithHostinet()) { + // There should be nothing to get via ioctl. 
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); + } } TEST_P(UdpSocketTest, WriteShutdownNotConnected) { @@ -1932,13 +1900,8 @@ TEST_P(UdpSocketTest, RecvBufLimits) { SyscallSucceeds()); } - // Now set the limit to min * 4. - int new_rcv_buf_sz = min * 4; - if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { - // Linux doubles the value specified so just set to min * 2. - new_rcv_buf_sz = min * 2; - } - + // Now set the limit to min * 2. + int new_rcv_buf_sz = min * 2; ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, sizeof(new_rcv_buf_sz)), SyscallSucceeds()); @@ -2051,68 +2014,57 @@ TEST_P(UdpSocketTest, SendToZeroPort) { // Sending to an invalid port should fail. SetPort(&addr, 0); - EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, - reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT( + sendto(sock_.get(), buf, sizeof(buf), 0, AsSockAddr(&addr), sizeof(addr)), + SyscallFailsWithErrno(EINVAL)); SetPort(&addr, 1234); - EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, - reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), - SyscallSucceedsWithValue(sizeof(buf))); + EXPECT_THAT( + sendto(sock_.get(), buf, sizeof(buf), 0, AsSockAddr(&addr), sizeof(addr)), + SyscallSucceedsWithValue(sizeof(buf))); } TEST_P(UdpSocketTest, ConnectToZeroPortUnbound) { struct sockaddr_storage addr = InetLoopbackAddr(); SetPort(&addr, 0); - ASSERT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), - SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), AsSockAddr(&addr), addrlen_), + SyscallSucceeds()); } TEST_P(UdpSocketTest, ConnectToZeroPortBound) { struct sockaddr_storage addr = InetLoopbackAddr(); - ASSERT_NO_ERRNO( - BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr))); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), AsSockAddr(&addr))); SetPort(&addr, 0); - ASSERT_THAT( - connect(sock_.get(), 
reinterpret_cast<struct sockaddr*>(&addr), addrlen_), - SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), AsSockAddr(&addr), addrlen_), + SyscallSucceeds()); socklen_t len = sizeof(sockaddr_storage); - ASSERT_THAT( - getsockname(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), &len), - SyscallSucceeds()); + ASSERT_THAT(getsockname(sock_.get(), AsSockAddr(&addr), &len), + SyscallSucceeds()); ASSERT_EQ(len, addrlen_); } TEST_P(UdpSocketTest, ConnectToZeroPortConnected) { struct sockaddr_storage addr = InetLoopbackAddr(); - ASSERT_NO_ERRNO( - BindSocket(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr))); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), AsSockAddr(&addr))); // Connect to an address with non-zero port should succeed. - ASSERT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), - SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), AsSockAddr(&addr), addrlen_), + SyscallSucceeds()); sockaddr_storage peername; socklen_t peerlen = sizeof(peername); - ASSERT_THAT( - getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername), - &peerlen), - SyscallSucceeds()); + ASSERT_THAT(getpeername(sock_.get(), AsSockAddr(&peername), &peerlen), + SyscallSucceeds()); ASSERT_EQ(peerlen, addrlen_); ASSERT_EQ(memcmp(&peername, &addr, addrlen_), 0); // However connect() to an address with port 0 will make the following // getpeername() fail. 
SetPort(&addr, 0); - ASSERT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen_), - SyscallSucceeds()); - ASSERT_THAT( - getpeername(sock_.get(), reinterpret_cast<struct sockaddr*>(&peername), - &peerlen), - SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(connect(sock_.get(), AsSockAddr(&addr), addrlen_), + SyscallSucceeds()); + ASSERT_THAT(getpeername(sock_.get(), AsSockAddr(&peername), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); } INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, @@ -2133,8 +2085,7 @@ TEST(UdpInet6SocketTest, ConnectInet4Sockaddr) { SyscallSucceeds()); sockaddr_storage sockname; socklen_t len = sizeof(sockaddr_storage); - ASSERT_THAT(getsockname(sock_.get(), - reinterpret_cast<struct sockaddr*>(&sockname), &len), + ASSERT_THAT(getsockname(sock_.get(), AsSockAddr(&sockname), &len), SyscallSucceeds()); ASSERT_EQ(sockname.ss_family, AF_INET6); ASSERT_EQ(len, sizeof(sockaddr_in6)); diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc index d8824b171..759ea4f53 100644 --- a/test/syscalls/linux/uname.cc +++ b/test/syscalls/linux/uname.cc @@ -76,9 +76,7 @@ TEST(UnameTest, SetNames) { } TEST(UnameTest, UnprivilegedSetNames) { - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); - } + AutoCapability cap(CAP_SYS_ADMIN, false); EXPECT_THAT(sethostname("", 0), SyscallFailsWithErrno(EPERM)); EXPECT_THAT(setdomainname("", 0), SyscallFailsWithErrno(EPERM)); diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc index 061e2e0f1..75dcf4465 100644 --- a/test/syscalls/linux/unlink.cc +++ b/test/syscalls/linux/unlink.cc @@ -64,10 +64,10 @@ TEST(UnlinkTest, AtDir) { ASSERT_THAT(close(dirfd), SyscallSucceeds()); } -TEST(UnlinkTest, AtDirDegradedPermissions_NoRandomSave) { +TEST(UnlinkTest, AtDirDegradedPermissions) { // Drop capabilities that allow us to override file and directory permissions. 
- ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -86,8 +86,8 @@ TEST(UnlinkTest, AtDirDegradedPermissions_NoRandomSave) { // Files cannot be unlinked if the parent is not writable and executable. TEST(UnlinkTest, ParentDegradedPermissions) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); @@ -162,7 +162,7 @@ TEST(UnlinkTest, AtFile) { EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds()); } -TEST(UnlinkTest, OpenFile_NoRandomSave) { +TEST(UnlinkTest, OpenFile) { // We can't save unlinked file unless they are on tmpfs. const DisableSave ds; auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc index e647d2896..e711d6657 100644 --- a/test/syscalls/linux/utimes.cc +++ b/test/syscalls/linux/utimes.cc @@ -225,7 +225,8 @@ void TestUtimensat(int dirFd, std::string const& path) { EXPECT_GE(mtime3, before); EXPECT_LE(mtime3, after); - EXPECT_EQ(atime3, mtime3); + // TODO(b/187074006): atime/mtime may differ with local_gofer_uncached. + // EXPECT_EQ(atime3, mtime3); } TEST(UtimensatTest, OnAbsPath) { diff --git a/test/syscalls/linux/verity_ioctl.cc b/test/syscalls/linux/verity_ioctl.cc new file mode 100644 index 000000000..822e16f3c --- /dev/null +++ b/test/syscalls/linux/verity_ioctl.cc @@ -0,0 +1,345 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <stdint.h> +#include <stdlib.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <time.h> + +#include <iomanip> +#include <sstream> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/fs_util.h" +#include "test/util/mount_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +#ifndef FS_IOC_ENABLE_VERITY +#define FS_IOC_ENABLE_VERITY 1082156677 +#endif + +#ifndef FS_IOC_MEASURE_VERITY +#define FS_IOC_MEASURE_VERITY 3221513862 +#endif + +#ifndef FS_VERITY_FL +#define FS_VERITY_FL 1048576 +#endif + +#ifndef FS_IOC_GETFLAGS +#define FS_IOC_GETFLAGS 2148034049 +#endif + +struct fsverity_digest { + __u16 digest_algorithm; + __u16 digest_size; /* input/output */ + __u8 digest[]; +}; + +constexpr int kMaxDigestSize = 64; +constexpr int kDefaultDigestSize = 32; +constexpr char kContents[] = "foobarbaz"; +constexpr char kMerklePrefix[] = ".merkle.verity."; +constexpr char kMerkleRootPrefix[] = ".merkleroot.verity."; + +class IoctlTest : public ::testing::Test { + protected: + void SetUp() override { + // Verity is implemented in VFS2. + SKIP_IF(IsRunningWithVFS1()); + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + // Mount a tmpfs file system, to be wrapped by a verity fs. 
+ tmpfs_dir_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(mount("", tmpfs_dir_.path().c_str(), "tmpfs", 0, ""), + SyscallSucceeds()); + + // Create a new file in the tmpfs mount. + file_ = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(tmpfs_dir_.path(), kContents, 0777)); + filename_ = Basename(file_.path()); + } + + TempPath tmpfs_dir_; + TempPath file_; + std::string filename_; +}; + +// Provide a function to convert bytes to hex string, since +// absl::BytesToHexString does not seem to be compatible with golang +// hex.DecodeString used in verity due to zero-padding. +std::string BytesToHexString(uint8_t bytes[], int size) { + std::stringstream ss; + ss << std::hex; + for (int i = 0; i < size; ++i) { + ss << std::setw(2) << std::setfill('0') << static_cast<int>(bytes[i]); + } + return ss.str(); +} + +std::string MerklePath(absl::string_view path) { + return JoinPath(Dirname(path), + std::string(kMerklePrefix) + std::string(Basename(path))); +} + +std::string MerkleRootPath(absl::string_view path) { + return JoinPath(Dirname(path), + std::string(kMerkleRootPrefix) + std::string(Basename(path))); +} + +// Flip a random bit in the file represented by fd. +PosixError FlipRandomBit(int fd, int size) { + // Generate a random offset in the file. + srand(time(nullptr)); + unsigned int seed = 0; + int random_offset = rand_r(&seed) % size; + + // Read a random byte and flip a bit in it. + char buf[1]; + RETURN_ERROR_IF_SYSCALL_FAIL(PreadFd(fd, buf, 1, random_offset)); + buf[0] ^= 1; + RETURN_ERROR_IF_SYSCALL_FAIL(PwriteFd(fd, buf, 1, random_offset)); + return NoError(); +} + +// Mount a verity on the tmpfs and enable both the file and the direcotry. Then +// mount a new verity with measured root hash. +PosixErrorOr<std::string> MountVerity(std::string tmpfs_dir, + std::string filename) { + // Mount a verity fs on the existing tmpfs mount. 
+ std::string mount_opts = "lower_path=" + tmpfs_dir; + ASSIGN_OR_RETURN_ERRNO(TempPath verity_dir, TempPath::CreateDir()); + RETURN_ERROR_IF_SYSCALL_FAIL( + mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str())); + + // Enable both the file and the directory. + ASSIGN_OR_RETURN_ERRNO( + auto fd, Open(JoinPath(verity_dir.path(), filename), O_RDONLY, 0777)); + RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd.get(), FS_IOC_ENABLE_VERITY)); + ASSIGN_OR_RETURN_ERRNO(auto dir_fd, Open(verity_dir.path(), O_RDONLY, 0777)); + RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(dir_fd.get(), FS_IOC_ENABLE_VERITY)); + + // Measure the root hash. + uint8_t digest_array[sizeof(struct fsverity_digest) + kMaxDigestSize] = {0}; + struct fsverity_digest* digest = + reinterpret_cast<struct fsverity_digest*>(digest_array); + digest->digest_size = kMaxDigestSize; + RETURN_ERROR_IF_SYSCALL_FAIL( + ioctl(dir_fd.get(), FS_IOC_MEASURE_VERITY, digest)); + + // Mount a verity fs with specified root hash. + mount_opts += + ",root_hash=" + BytesToHexString(digest->digest, digest->digest_size); + ASSIGN_OR_RETURN_ERRNO(TempPath verity_with_hash_dir, TempPath::CreateDir()); + RETURN_ERROR_IF_SYSCALL_FAIL(mount("", verity_with_hash_dir.path().c_str(), + "verity", 0, mount_opts.c_str())); + // Verity directories should not be deleted. Release the TempPath objects to + // prevent those directories from being deleted by the destructor. + verity_dir.release(); + return verity_with_hash_dir.release(); +} + +TEST_F(IoctlTest, Enable) { + // Mount a verity fs on the existing tmpfs mount. + std::string mount_opts = "lower_path=" + tmpfs_dir_.path(); + auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT( + mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()), + SyscallSucceeds()); + + // Confirm that the verity flag is absent. 
+ int flag = 0; + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777)); + ASSERT_THAT(ioctl(fd.get(), FS_IOC_GETFLAGS, &flag), SyscallSucceeds()); + EXPECT_EQ(flag & FS_VERITY_FL, 0); + + // Enable the file and confirm that the verity flag is present. + ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds()); + ASSERT_THAT(ioctl(fd.get(), FS_IOC_GETFLAGS, &flag), SyscallSucceeds()); + EXPECT_EQ(flag & FS_VERITY_FL, FS_VERITY_FL); +} + +TEST_F(IoctlTest, Measure) { + // Mount a verity fs on the existing tmpfs mount. + std::string mount_opts = "lower_path=" + tmpfs_dir_.path(); + auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT( + mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()), + SyscallSucceeds()); + + // Confirm that the file cannot be measured. + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777)); + uint8_t digest_array[sizeof(struct fsverity_digest) + kMaxDigestSize] = {0}; + struct fsverity_digest* digest = + reinterpret_cast<struct fsverity_digest*>(digest_array); + digest->digest_size = kMaxDigestSize; + ASSERT_THAT(ioctl(fd.get(), FS_IOC_MEASURE_VERITY, digest), + SyscallFailsWithErrno(ENODATA)); + + // Enable the file and confirm that the file can be measured. + ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds()); + ASSERT_THAT(ioctl(fd.get(), FS_IOC_MEASURE_VERITY, digest), + SyscallSucceeds()); + EXPECT_EQ(digest->digest_size, kDefaultDigestSize); +} + +TEST_F(IoctlTest, Mount) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + // Make sure the file can be open and read in the mounted verity fs. 
+ auto const verity_fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(verity_dir, filename_), O_RDONLY, 0777)); + char buf[sizeof(kContents)]; + EXPECT_THAT(ReadFd(verity_fd.get(), buf, sizeof(kContents)), + SyscallSucceeds()); +} + +TEST_F(IoctlTest, NonExistingFile) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + // Confirm that opening a non-existing file in the verity-enabled directory + // triggers the expected error instead of verification failure. + EXPECT_THAT( + open(JoinPath(verity_dir, filename_ + "abc").c_str(), O_RDONLY, 0777), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_F(IoctlTest, ModifiedFile) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + // Modify the file and check verification failure upon reading from it. + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(tmpfs_dir_.path(), filename_), O_RDWR, 0777)); + ASSERT_NO_ERRNO(FlipRandomBit(fd.get(), sizeof(kContents) - 1)); + + auto const verity_fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(verity_dir, filename_), O_RDONLY, 0777)); + char buf[sizeof(kContents)]; + EXPECT_THAT(pread(verity_fd.get(), buf, 16, 0), SyscallFailsWithErrno(EIO)); +} + +TEST_F(IoctlTest, ModifiedMerkle) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + // Modify the Merkle file and check verification failure upon opening the + // corresponding file. 
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(MerklePath(JoinPath(tmpfs_dir_.path(), filename_)), O_RDWR, 0777)); + auto stat = ASSERT_NO_ERRNO_AND_VALUE(Fstat(fd.get())); + ASSERT_NO_ERRNO(FlipRandomBit(fd.get(), stat.st_size)); + + EXPECT_THAT(open(JoinPath(verity_dir, filename_).c_str(), O_RDONLY, 0777), + SyscallFailsWithErrno(EIO)); +} + +TEST_F(IoctlTest, ModifiedDirMerkle) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + // Modify the Merkle file for the parent directory and check verification + // failure upon opening the corresponding file. + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(MerkleRootPath(JoinPath(tmpfs_dir_.path(), "root")), O_RDWR, 0777)); + auto stat = ASSERT_NO_ERRNO_AND_VALUE(Fstat(fd.get())); + ASSERT_NO_ERRNO(FlipRandomBit(fd.get(), stat.st_size)); + + EXPECT_THAT(open(JoinPath(verity_dir, filename_).c_str(), O_RDONLY, 0777), + SyscallFailsWithErrno(EIO)); +} + +TEST_F(IoctlTest, Stat) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + struct stat st; + EXPECT_THAT(stat(JoinPath(verity_dir, filename_).c_str(), &st), + SyscallSucceeds()); +} + +TEST_F(IoctlTest, ModifiedStat) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + EXPECT_THAT(chmod(JoinPath(tmpfs_dir_.path(), filename_).c_str(), 0644), + SyscallSucceeds()); + struct stat st; + EXPECT_THAT(stat(JoinPath(verity_dir, filename_).c_str(), &st), + SyscallFailsWithErrno(EIO)); +} + +TEST_F(IoctlTest, DeleteFile) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + EXPECT_THAT(unlink(JoinPath(tmpfs_dir_.path(), filename_).c_str()), + SyscallSucceeds()); + EXPECT_THAT(open(JoinPath(verity_dir, filename_).c_str(), O_RDONLY, 0777), + SyscallFailsWithErrno(EIO)); +} + +TEST_F(IoctlTest, DeleteMerkle) { + std::string verity_dir = + 
ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + EXPECT_THAT( + unlink(MerklePath(JoinPath(tmpfs_dir_.path(), filename_)).c_str()), + SyscallSucceeds()); + EXPECT_THAT(open(JoinPath(verity_dir, filename_).c_str(), O_RDONLY, 0777), + SyscallFailsWithErrno(EIO)); +} + +TEST_F(IoctlTest, RenameFile) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + std::string new_file_name = "renamed-" + filename_; + EXPECT_THAT(rename(JoinPath(tmpfs_dir_.path(), filename_).c_str(), + JoinPath(tmpfs_dir_.path(), new_file_name).c_str()), + SyscallSucceeds()); + EXPECT_THAT(open(JoinPath(verity_dir, filename_).c_str(), O_RDONLY, 0777), + SyscallFailsWithErrno(EIO)); +} + +TEST_F(IoctlTest, RenameMerkle) { + std::string verity_dir = + ASSERT_NO_ERRNO_AND_VALUE(MountVerity(tmpfs_dir_.path(), filename_)); + + std::string new_file_name = "renamed-" + filename_; + EXPECT_THAT( + rename(MerklePath(JoinPath(tmpfs_dir_.path(), filename_)).c_str(), + MerklePath(JoinPath(tmpfs_dir_.path(), new_file_name)).c_str()), + SyscallSucceeds()); + EXPECT_THAT(open(JoinPath(verity_dir, filename_).c_str(), O_RDONLY, 0777), + SyscallFailsWithErrno(EIO)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/verity_mount.cc b/test/syscalls/linux/verity_mount.cc new file mode 100644 index 000000000..e73dd5599 --- /dev/null +++ b/test/syscalls/linux/verity_mount.cc @@ -0,0 +1,53 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/mount.h> + +#include <iomanip> +#include <sstream> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Mount verity file system on an existing gofer mount. +TEST(MountTest, MountExisting) { + // Verity is implemented in VFS2. + SKIP_IF(IsRunningWithVFS1()); + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + + // Mount a new tmpfs file system. + auto const tmpfs_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(mount("", tmpfs_dir.path().c_str(), "tmpfs", 0, ""), + SyscallSucceeds()); + + // Mount a verity file system on the existing gofer mount. + auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string opts = "lower_path=" + tmpfs_dir.path(); + EXPECT_THAT(mount("", verity_dir.path().c_str(), "verity", 0, opts.c_str()), + SyscallSucceeds()); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc index 19d05998e..1a282e371 100644 --- a/test/syscalls/linux/vfork.cc +++ b/test/syscalls/linux/vfork.cc @@ -87,7 +87,7 @@ TEST(VforkTest, ParentStopsUntilChildExits) { EXPECT_THAT(InForkedProcess(test), IsPosixErrorOkAndHolds(0)); } -TEST(VforkTest, ParentStopsUntilChildExecves_NoRandomSave) { +TEST(VforkTest, ParentStopsUntilChildExecves) { ExecveArray const owned_child_argv = {"/proc/self/exe", "--vfork_test_child"}; char* const* const child_argv = owned_child_argv.get(); @@ -127,7 +127,7 @@ TEST(VforkTest, ParentStopsUntilChildExecves_NoRandomSave) { // A vfork child does not unstop the parent a second time when it exits after // exec. 
-TEST(VforkTest, ExecedChildExitDoesntUnstopParent_NoRandomSave) { +TEST(VforkTest, ExecedChildExitDoesntUnstopParent) { ExecveArray const owned_child_argv = {"/proc/self/exe", "--vfork_test_child"}; char* const* const child_argv = owned_child_argv.get(); diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index a953a55fe..c8a97df6b 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -107,10 +107,10 @@ TEST_F(XattrTest, XattrInvalidPrefix) { // Do not allow save/restore cycles after making the test file read-only, as // the restore will fail to open it with r/w permissions. -TEST_F(XattrTest, XattrReadOnly_NoRandomSave) { +TEST_F(XattrTest, XattrReadOnly) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); const char* path = test_file_name_.c_str(); const char name[] = "user.test"; @@ -138,10 +138,10 @@ TEST_F(XattrTest, XattrReadOnly_NoRandomSave) { // Do not allow save/restore cycles after making the test file write-only, as // the restore will fail to open it with r/w permissions. -TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) { +TEST_F(XattrTest, XattrWriteOnly) { // Drop capabilities that allow us to override file and directory permissions. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + AutoCapability cap1(CAP_DAC_OVERRIDE, false); + AutoCapability cap2(CAP_DAC_READ_SEARCH, false); DisableSave ds; ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR)); @@ -632,7 +632,7 @@ TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) { // Trusted namespace not supported in VFS1. SKIP_IF(IsRunningWithVFS1()); - // TODO(b/66162845): Only gVisor tmpfs currently supports trusted namespace. 
+ // TODO(b/166162845): Only gVisor tmpfs currently supports trusted namespace. SKIP_IF(IsRunningOnGvisor() && !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); @@ -680,9 +680,7 @@ TEST_F(XattrTest, TrustedNamespaceWithoutCapSysAdmin) { !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); // Drop CAP_SYS_ADMIN if we have it. - if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { - EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); - } + AutoCapability cap(CAP_SYS_ADMIN, false); const char* path = test_file_name_.c_str(); const char name[] = "trusted.test"; diff --git a/test/util/BUILD b/test/util/BUILD index e561f3daa..8985b54af 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -94,6 +94,7 @@ cc_library( ":file_descriptor", ":posix_error", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", gtest, ], ) @@ -136,11 +137,26 @@ cc_library( cc_library( name = "mount_util", testonly = 1, + srcs = ["mount_util.cc"], hdrs = ["mount_util.h"], deps = [ ":cleanup", ":posix_error", ":test_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + gtest, + ], +) + +cc_test( + name = "mount_util_test", + size = "small", + srcs = ["mount_util_test.cc"], + deps = [ + ":mount_util", + ":test_main", + ":test_util", gtest, ], ) @@ -368,3 +384,20 @@ cc_library( testonly = 1, hdrs = ["temp_umask.h"], ) + +cc_library( + name = "cgroup_util", + testonly = 1, + srcs = ["cgroup_util.cc"], + hdrs = ["cgroup_util.h"], + deps = [ + ":cleanup", + ":fs_util", + ":mount_util", + ":posix_error", + ":temp_path", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) diff --git a/test/util/capability_util.h b/test/util/capability_util.h index a03bc7e05..f2c370125 100644 --- a/test/util/capability_util.h +++ b/test/util/capability_util.h @@ -99,14 +99,23 @@ PosixErrorOr<bool> CanCreateUserNamespace(); class AutoCapability { public: 
AutoCapability(int cap, bool set) : cap_(cap), set_(set) { - EXPECT_NO_ERRNO(SetCapability(cap_, set_)); + const bool has = EXPECT_NO_ERRNO_AND_VALUE(HaveCapability(cap)); + if (set != has) { + EXPECT_NO_ERRNO(SetCapability(cap_, set_)); + applied_ = true; + } } - ~AutoCapability() { EXPECT_NO_ERRNO(SetCapability(cap_, !set_)); } + ~AutoCapability() { + if (applied_) { + EXPECT_NO_ERRNO(SetCapability(cap_, !set_)); + } + } private: int cap_; bool set_; + bool applied_ = false; }; } // namespace testing diff --git a/test/util/cgroup_util.cc b/test/util/cgroup_util.cc new file mode 100644 index 000000000..04d4f8de0 --- /dev/null +++ b/test/util/cgroup_util.cc @@ -0,0 +1,236 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "test/util/cgroup_util.h" + +#include <sys/syscall.h> +#include <unistd.h> + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "test/util/fs_util.h" +#include "test/util/mount_util.h" + +namespace gvisor { +namespace testing { + +Cgroup::Cgroup(std::string_view path) : cgroup_path_(path) { + id_ = ++Cgroup::next_id_; + std::cerr << absl::StreamFormat("[cg#%d] <= %s", id_, cgroup_path_) + << std::endl; +} + +PosixErrorOr<std::string> Cgroup::ReadControlFile( + absl::string_view name) const { + std::string buf; + RETURN_IF_ERRNO(GetContents(Relpath(name), &buf)); + + const std::string alias_path = absl::StrFormat("[cg#%d]/%s", id_, name); + std::cerr << absl::StreamFormat("<contents of %s>", alias_path) << std::endl; + std::cerr << buf; + std::cerr << absl::StreamFormat("<end of %s>", alias_path) << std::endl; + + return buf; +} + +PosixErrorOr<int64_t> Cgroup::ReadIntegerControlFile( + absl::string_view name) const { + ASSIGN_OR_RETURN_ERRNO(const std::string buf, ReadControlFile(name)); + ASSIGN_OR_RETURN_ERRNO(const int64_t val, Atoi<int64_t>(buf)); + return val; +} + +PosixError Cgroup::WriteControlFile(absl::string_view name, + const std::string& value) const { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(Relpath(name), O_WRONLY)); + RETURN_ERROR_IF_SYSCALL_FAIL(WriteFd(fd.get(), value.c_str(), value.size())); + return NoError(); +} + +PosixError Cgroup::WriteIntegerControlFile(absl::string_view name, + int64_t value) const { + return WriteControlFile(name, absl::StrCat(value)); +} + +PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::Procs() const { + ASSIGN_OR_RETURN_ERRNO(std::string buf, ReadControlFile("cgroup.procs")); + return ParsePIDList(buf); +} + +PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::Tasks() const { + ASSIGN_OR_RETURN_ERRNO(std::string buf, ReadControlFile("tasks")); + return ParsePIDList(buf); +} + +PosixError Cgroup::ContainsCallingProcess() const { + ASSIGN_OR_RETURN_ERRNO(const 
absl::flat_hash_set<pid_t> procs, Procs()); + ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> tasks, Tasks()); + const pid_t pid = getpid(); + const pid_t tid = syscall(SYS_gettid); + if (!procs.contains(pid)) { + return PosixError( + ENOENT, absl::StrFormat("Cgroup doesn't contain process %d", pid)); + } + if (!tasks.contains(tid)) { + return PosixError(ENOENT, + absl::StrFormat("Cgroup doesn't contain task %d", tid)); + } + return NoError(); +} + +PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::ParsePIDList( + absl::string_view data) const { + absl::flat_hash_set<pid_t> res; + std::vector<absl::string_view> lines = absl::StrSplit(data, '\n'); + for (const std::string_view& line : lines) { + if (line.empty()) { + continue; + } + ASSIGN_OR_RETURN_ERRNO(const int32_t pid, Atoi<int32_t>(line)); + res.insert(static_cast<pid_t>(pid)); + } + return res; +} + +int64_t Cgroup::next_id_ = 0; + +PosixErrorOr<Cgroup> Mounter::MountCgroupfs(std::string mopts) { + ASSIGN_OR_RETURN_ERRNO(TempPath mountpoint, + TempPath::CreateDirIn(root_.path())); + ASSIGN_OR_RETURN_ERRNO( + Cleanup mount, Mount("none", mountpoint.path(), "cgroup", 0, mopts, 0)); + const std::string mountpath = mountpoint.path(); + std::cerr << absl::StreamFormat( + "Mount(\"none\", \"%s\", \"cgroup\", 0, \"%s\", 0) => OK", + mountpath, mopts) + << std::endl; + Cgroup cg = Cgroup(mountpath); + mountpoints_[cg.id()] = std::move(mountpoint); + mounts_[cg.id()] = std::move(mount); + return cg; +} + +PosixError Mounter::Unmount(const Cgroup& c) { + auto mount = mounts_.find(c.id()); + auto mountpoint = mountpoints_.find(c.id()); + + if (mount == mounts_.end() || mountpoint == mountpoints_.end()) { + return PosixError( + ESRCH, absl::StrFormat("No mount found for cgroupfs containing cg#%d", + c.id())); + } + + std::cerr << absl::StreamFormat("Unmount([cg#%d])", c.id()) << std::endl; + + // Simply delete the entries, their destructors will unmount and delete the + // mountpoint. 
Note the order is important to avoid errors: mount then + // mountpoint. + mounts_.erase(mount); + mountpoints_.erase(mountpoint); + + return NoError(); +} + +constexpr char kProcCgroupsHeader[] = + "#subsys_name\thierarchy\tnum_cgroups\tenabled"; + +PosixErrorOr<absl::flat_hash_map<std::string, CgroupsEntry>> +ProcCgroupsEntries() { + std::string content; + RETURN_IF_ERRNO(GetContents("/proc/cgroups", &content)); + + bool found_header = false; + absl::flat_hash_map<std::string, CgroupsEntry> entries; + std::vector<std::string> lines = absl::StrSplit(content, '\n'); + std::cerr << "<contents of /proc/cgroups>" << std::endl; + for (const std::string& line : lines) { + std::cerr << line << std::endl; + + if (!found_header) { + EXPECT_EQ(line, kProcCgroupsHeader); + found_header = true; + continue; + } + if (line.empty()) { + continue; + } + + // Parse a single entry from /proc/cgroups. + // + // Example entries, fields are tab separated in the real file: + // + // #subsys_name hierarchy num_cgroups enabled + // cpuset 12 35 1 + // cpu 3 222 1 + // ^ ^ ^ ^ + // 0 1 2 3 + + CgroupsEntry entry; + std::vector<std::string> fields = + StrSplit(line, absl::ByAnyChar(": \t"), absl::SkipEmpty()); + + entry.subsys_name = fields[0]; + ASSIGN_OR_RETURN_ERRNO(entry.hierarchy, Atoi<uint32_t>(fields[1])); + ASSIGN_OR_RETURN_ERRNO(entry.num_cgroups, Atoi<uint64_t>(fields[2])); + ASSIGN_OR_RETURN_ERRNO(const int enabled, Atoi<int>(fields[3])); + entry.enabled = enabled != 0; + + entries[entry.subsys_name] = entry; + } + std::cerr << "<end of /proc/cgroups>" << std::endl; + + return entries; +} + +PosixErrorOr<absl::flat_hash_map<std::string, PIDCgroupEntry>> +ProcPIDCgroupEntries(pid_t pid) { + const std::string path = absl::StrFormat("/proc/%d/cgroup", pid); + std::string content; + RETURN_IF_ERRNO(GetContents(path, &content)); + + absl::flat_hash_map<std::string, PIDCgroupEntry> entries; + std::vector<std::string> lines = absl::StrSplit(content, '\n'); + + std::cerr << 
absl::StreamFormat("<contents of %s>", path) << std::endl; + for (const std::string& line : lines) { + std::cerr << line << std::endl; + + if (line.empty()) { + continue; + } + + // Parse a single entry from /proc/<pid>/cgroup. + // + // Example entries: + // + // 2:cpu:/path/to/cgroup + // 1:memory:/ + + PIDCgroupEntry entry; + std::vector<std::string> fields = + absl::StrSplit(line, absl::ByChar(':'), absl::SkipEmpty()); + + ASSIGN_OR_RETURN_ERRNO(entry.hierarchy, Atoi<uint32_t>(fields[0])); + entry.controllers = fields[1]; + entry.path = fields[2]; + + entries[entry.controllers] = entry; + } + std::cerr << absl::StreamFormat("<end of %s>", path) << std::endl; + + return entries; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/cgroup_util.h b/test/util/cgroup_util.h new file mode 100644 index 000000000..b797a8b24 --- /dev/null +++ b/test/util/cgroup_util.h @@ -0,0 +1,121 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_UTIL_CGROUP_UTIL_H_ +#define GVISOR_TEST_UTIL_CGROUP_UTIL_H_ + +#include <unistd.h> + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +// Cgroup represents a cgroup directory on a mounted cgroupfs. 
+class Cgroup { + public: + Cgroup(std::string_view path); + + uint64_t id() const { return id_; } + + const std::string& Path() const { return cgroup_path_; } + + std::string Relpath(absl::string_view leaf) const { + return JoinPath(cgroup_path_, leaf); + } + + // Returns the contents of a cgroup control file with the given name. + PosixErrorOr<std::string> ReadControlFile(absl::string_view name) const; + + // Reads the contents of a cgroup control with the given name, and attempts + // to parse it as an integer. + PosixErrorOr<int64_t> ReadIntegerControlFile(absl::string_view name) const; + + // Writes a string to a cgroup control file. + PosixError WriteControlFile(absl::string_view name, + const std::string& value) const; + + // Writes an integer value to a cgroup control file. + PosixError WriteIntegerControlFile(absl::string_view name, + int64_t value) const; + + // Returns the thread ids of the leaders of thread groups managed by this + // cgroup. + PosixErrorOr<absl::flat_hash_set<pid_t>> Procs() const; + + PosixErrorOr<absl::flat_hash_set<pid_t>> Tasks() const; + + // ContainsCallingProcess checks whether the calling process is part of the + PosixError ContainsCallingProcess() const; + + private: + PosixErrorOr<absl::flat_hash_set<pid_t>> ParsePIDList( + absl::string_view data) const; + + static int64_t next_id_; + int64_t id_; + const std::string cgroup_path_; +}; + +// Mounter is a utility for creating cgroupfs mounts. It automatically manages +// the lifetime of created mounts. +class Mounter { + public: + Mounter(TempPath root) : root_(std::move(root)) {} + + PosixErrorOr<Cgroup> MountCgroupfs(std::string mopts); + + PosixError Unmount(const Cgroup& c); + + private: + // The destruction order of these members avoids errors during cleanup. We + // first unmount (by executing the mounts_ cleanups), then delete the + // mountpoint subdirs, then delete the root. 
+ TempPath root_; + absl::flat_hash_map<int64_t, TempPath> mountpoints_; + absl::flat_hash_map<int64_t, Cleanup> mounts_; +}; + +// Represents a line from /proc/cgroups. +struct CgroupsEntry { + std::string subsys_name; + uint32_t hierarchy; + uint64_t num_cgroups; + bool enabled; +}; + +// Returns a parsed representation of /proc/cgroups. +PosixErrorOr<absl::flat_hash_map<std::string, CgroupsEntry>> +ProcCgroupsEntries(); + +// Represents a line from /proc/<pid>/cgroup. +struct PIDCgroupEntry { + uint32_t hierarchy; + std::string controllers; + std::string path; +}; + +// Returns a parsed representation of /proc/<pid>/cgroup. +PosixErrorOr<absl::flat_hash_map<std::string, PIDCgroupEntry>> +ProcPIDCgroupEntries(pid_t pid); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_UTIL_CGROUP_UTIL_H_ diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 5f1ce0d8a..483ae848d 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -28,6 +28,8 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -366,6 +368,48 @@ PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath, return files; } +PosixError DirContains(absl::string_view path, + const std::vector<std::string>& expect, + const std::vector<std::string>& exclude) { + ASSIGN_OR_RETURN_ERRNO(auto listing, ListDir(path, false)); + + for (auto& expected_entry : expect) { + auto cursor = std::find(listing.begin(), listing.end(), expected_entry); + if (cursor == listing.end()) { + return PosixError(ENOENT, absl::StrFormat("Failed to find '%s' in '%s'", + expected_entry, path)); + } + } + for (auto& excluded_entry : exclude) { + auto cursor = std::find(listing.begin(), listing.end(), excluded_entry); + if (cursor != listing.end()) { + return 
PosixError(ENOENT, absl::StrCat("File '", excluded_entry, + "' found in path '", path, "'")); + } + } + return NoError(); +} + +PosixError EventuallyDirContains(absl::string_view path, + const std::vector<std::string>& expect, + const std::vector<std::string>& exclude) { + constexpr int kRetryCount = 100; + const absl::Duration kRetryDelay = absl::Milliseconds(100); + + for (int i = 0; i < kRetryCount; ++i) { + auto res = DirContains(path, expect, exclude); + if (res.ok()) { + return res; + } + if (i < kRetryCount - 1) { + // Sleep if this isn't the final iteration. + absl::SleepFor(kRetryDelay); + } + } + return PosixError(ETIMEDOUT, + "Timed out while waiting for directory to contain files "); +} + PosixError RecursivelyDelete(absl::string_view path, int* undeleted_dirs, int* undeleted_files) { ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path)); diff --git a/test/util/fs_util.h b/test/util/fs_util.h index 2190c3bca..bb2d1d3c8 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -129,6 +129,18 @@ PosixError WalkTree( PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath, bool skipdots); +// Check that a directory contains children nodes named in expect, and does not +// contain any children nodes named in exclude. +PosixError DirContains(absl::string_view path, + const std::vector<std::string>& expect, + const std::vector<std::string>& exclude); + +// Same as DirContains, but adds a retry. Suitable for checking a directory +// being modified asynchronously. +PosixError EventuallyDirContains(absl::string_view path, + const std::vector<std::string>& expect, + const std::vector<std::string>& exclude); + // Attempt to recursively delete a directory or file. Returns an error and // the number of undeleted directories and files. If either // undeleted_dirs or undeleted_files is nullptr then it will not be used. 
diff --git a/test/util/mount_util.cc b/test/util/mount_util.cc new file mode 100644 index 000000000..48640d6a1 --- /dev/null +++ b/test/util/mount_util.cc @@ -0,0 +1,176 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/mount_util.h" + +#include <sys/syscall.h> +#include <unistd.h> + +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<std::vector<ProcMountsEntry>> ProcSelfMountsEntries() { + std::string content; + RETURN_IF_ERRNO(GetContents("/proc/self/mounts", &content)); + return ProcSelfMountsEntriesFrom(content); +} + +PosixErrorOr<std::vector<ProcMountsEntry>> ProcSelfMountsEntriesFrom( + const std::string& content) { + std::vector<ProcMountsEntry> entries; + std::vector<std::string> lines = + absl::StrSplit(content, absl::ByChar('\n'), absl::AllowEmpty()); + std::cerr << "<contents of /proc/self/mounts>" << std::endl; + for (const std::string& line : lines) { + std::cerr << line << std::endl; + if (line.empty()) { + continue; + } + + // Parse a single entry from /proc/self/mounts. 
+ // + // Example entries: + // + // sysfs /sys sysfs rw,nosuid,nodev,noexec,relatime 0 0 + // proc /proc proc rw,nosuid,nodev,noexec,relatime 0 0 + // ^ ^ ^ ^ ^ ^ + // 0 1 2 3 4 5 + + ProcMountsEntry entry; + std::vector<std::string> fields = + absl::StrSplit(line, absl::ByChar(' '), absl::AllowEmpty()); + if (fields.size() != 6) { + return PosixError( + EINVAL, absl::StrFormat("Not enough tokens, got %d, content: <<%s>>", + fields.size(), content)); + } + + entry.spec = fields[0]; + entry.mount_point = fields[1]; + entry.fstype = fields[2]; + entry.mount_opts = fields[3]; + ASSIGN_OR_RETURN_ERRNO(entry.dump, Atoi<uint32_t>(fields[4])); + ASSIGN_OR_RETURN_ERRNO(entry.fsck, Atoi<uint32_t>(fields[5])); + + entries.push_back(entry); + } + std::cerr << "<end of /proc/self/mounts>" << std::endl; + + return entries; +} + +PosixErrorOr<std::vector<ProcMountInfoEntry>> ProcSelfMountInfoEntries() { + std::string content; + RETURN_IF_ERRNO(GetContents("/proc/self/mountinfo", &content)); + return ProcSelfMountInfoEntriesFrom(content); +} + +PosixErrorOr<std::vector<ProcMountInfoEntry>> ProcSelfMountInfoEntriesFrom( + const std::string& content) { + std::vector<ProcMountInfoEntry> entries; + std::vector<std::string> lines = + absl::StrSplit(content, absl::ByChar('\n'), absl::AllowEmpty()); + std::cerr << "<contents of /proc/self/mountinfo>" << std::endl; + for (const std::string& line : lines) { + std::cerr << line << std::endl; + if (line.empty()) { + continue; + } + + // Parse a single entry from /proc/self/mountinfo. 
+ // + // Example entries: + // + // 22 28 0:20 / /sys rw,relatime shared:7 - sysfs sysfs rw + // 23 28 0:21 / /proc rw,relatime shared:14 - proc proc rw + // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ + // 0 1 2 3 4 5 6 7 8 9 10 + + ProcMountInfoEntry entry; + std::vector<std::string> fields = + absl::StrSplit(line, absl::ByChar(' '), absl::AllowEmpty()); + if (fields.size() < 10 || fields.size() > 11) { + return PosixError( + EINVAL, absl::StrFormat( + "Unexpected number of tokens, got %d, content: <<%s>>", + fields.size(), content)); + } + + ASSIGN_OR_RETURN_ERRNO(entry.id, Atoi<uint64_t>(fields[0])); + ASSIGN_OR_RETURN_ERRNO(entry.parent_id, Atoi<uint64_t>(fields[1])); + + std::vector<std::string> devs = + absl::StrSplit(fields[2], absl::ByChar(':')); + if (devs.size() != 2) { + return PosixError( + EINVAL, + absl::StrFormat( + "Failed to parse dev number field %s: too many tokens, got %d", + fields[2], devs.size())); + } + ASSIGN_OR_RETURN_ERRNO(entry.major, Atoi<dev_t>(devs[0])); + ASSIGN_OR_RETURN_ERRNO(entry.minor, Atoi<dev_t>(devs[1])); + + entry.root = fields[3]; + entry.mount_point = fields[4]; + entry.mount_opts = fields[5]; + + // The optional field (fields[6]) may or may not be present. We know based + // on the total number of tokens. + int off = -1; + if (fields.size() == 11) { + entry.optional = fields[6]; + off = 0; + } + // Field 7 is the optional field terminator char '-'. 
+ entry.fstype = fields[8 + off]; + entry.mount_source = fields[9 + off]; + entry.super_opts = fields[10 + off]; + + entries.push_back(entry); + } + std::cerr << "<end of /proc/self/mountinfo>" << std::endl; + + return entries; +} + +absl::flat_hash_map<std::string, std::string> ParseMountOptions( + std::string mopts) { + absl::flat_hash_map<std::string, std::string> entries; + const std::vector<std::string> tokens = + absl::StrSplit(mopts, absl::ByChar(','), absl::AllowEmpty()); + for (const auto& token : tokens) { + std::vector<std::string> kv = + absl::StrSplit(token, absl::MaxSplits('=', 1)); + if (kv.size() == 2) { + entries[kv[0]] = kv[1]; + } else if (kv.size() == 1) { + entries[kv[0]] = ""; + } else { + TEST_CHECK_MSG( + false, + absl::StrFormat( + "Invalid mount option token '%s', was split into %d subtokens", + token, kv.size()) + .c_str()); + } + } + return entries; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/mount_util.h b/test/util/mount_util.h index 09e2281eb..3f8a1c0f1 100644 --- a/test/util/mount_util.h +++ b/test/util/mount_util.h @@ -22,6 +22,7 @@ #include <string> #include "gmock/gmock.h" +#include "absl/container/flat_hash_map.h" #include "test/util/cleanup.h" #include "test/util/posix_error.h" #include "test/util/test_util.h" @@ -45,6 +46,53 @@ inline PosixErrorOr<Cleanup> Mount(const std::string& source, }); } +struct ProcMountsEntry { + std::string spec; + std::string mount_point; + std::string fstype; + std::string mount_opts; + uint32_t dump; + uint32_t fsck; +}; + +// ProcSelfMountsEntries returns a parsed representation of /proc/self/mounts. +PosixErrorOr<std::vector<ProcMountsEntry>> ProcSelfMountsEntries(); + +// ProcSelfMountsEntries returns a parsed representation of mounts from the +// provided content. 
+PosixErrorOr<std::vector<ProcMountsEntry>> ProcSelfMountsEntriesFrom( + const std::string& content); + +struct ProcMountInfoEntry { + uint64_t id; + uint64_t parent_id; + dev_t major; + dev_t minor; + std::string root; + std::string mount_point; + std::string mount_opts; + std::string optional; + std::string fstype; + std::string mount_source; + std::string super_opts; +}; + +// ProcSelfMountInfoEntries returns a parsed representation of +// /proc/self/mountinfo. +PosixErrorOr<std::vector<ProcMountInfoEntry>> ProcSelfMountInfoEntries(); + +// ProcSelfMountInfoEntriesFrom returns a parsed representation of +// mountinfo from the provided content. +PosixErrorOr<std::vector<ProcMountInfoEntry>> ProcSelfMountInfoEntriesFrom( + const std::string&); + +// Interprets the input string mopts as a comma separated list of mount +// options. A mount option can either be just a value, or a key=value pair. For +// example, the string "rw,relatime,fd=7" will be parsed into a map like { "rw": +// "", "relatime": "", "fd": "7" }. +absl::flat_hash_map<std::string, std::string> ParseMountOptions( + std::string mopts); + } // namespace testing } // namespace gvisor diff --git a/test/util/mount_util_test.cc b/test/util/mount_util_test.cc new file mode 100644 index 000000000..2bcb6cc43 --- /dev/null +++ b/test/util/mount_util_test.cc @@ -0,0 +1,47 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "test/util/mount_util.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +TEST(ParseMounts, Mounts) { + auto entries = ASSERT_NO_ERRNO_AND_VALUE(ProcSelfMountsEntriesFrom( + R"proc(sysfs /sys sysfs rw,nosuid,nodev,noexec,relatime 0 0 +proc /proc proc rw,nosuid,nodev,noexec,relatime 0 0 + /mnt tmpfs rw,noexec 0 0 +)proc")); + EXPECT_EQ(entries.size(), 3); +} + +TEST(ParseMounts, MountInfo) { + auto entries = ASSERT_NO_ERRNO_AND_VALUE(ProcSelfMountInfoEntriesFrom( + R"proc(22 28 0:20 / /sys rw,relatime shared:7 - sysfs sysfs rw +23 28 0:21 / /proc rw,relatime shared:14 - proc proc rw +2007 8844 0:278 / /mnt rw,noexec - tmpfs rw,mode=123,uid=268601820,gid=5000 +)proc")); + EXPECT_EQ(entries.size(), 3); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/util/posix_error.h b/test/util/posix_error.h index 27557ad44..9ca09b77c 100644 --- a/test/util/posix_error.h +++ b/test/util/posix_error.h @@ -438,6 +438,13 @@ IsPosixErrorOkAndHolds(InnerMatcher&& inner_matcher) { std::move(_expr_result).ValueOrDie(); \ }) +#define EXPECT_NO_ERRNO_AND_VALUE(expr) \ + ({ \ + auto _expr_result = (expr); \ + EXPECT_NO_ERRNO(_expr_result); \ + std::move(_expr_result).ValueOrDie(); \ + }) + } // namespace testing } // namespace gvisor diff --git a/test/util/save_util.cc b/test/util/save_util.cc index 59d47e06e..3e724d99b 100644 --- a/test/util/save_util.cc +++ b/test/util/save_util.cc @@ -27,23 +27,13 @@ namespace gvisor { namespace testing { namespace { -std::atomic<absl::optional<bool>> cooperative_save_present; -std::atomic<absl::optional<bool>> random_save_present; +std::atomic<absl::optional<bool>> save_present; -bool CooperativeSavePresent() { - auto present = cooperative_save_present.load(); +bool SavePresent() { + auto present = save_present.load(); if (!present.has_value()) { - present = getenv("GVISOR_COOPERATIVE_SAVE_TEST") != 
nullptr; - cooperative_save_present.store(present); - } - return present.value(); -} - -bool RandomSavePresent() { - auto present = random_save_present.load(); - if (!present.has_value()) { - present = getenv("GVISOR_RANDOM_SAVE_TEST") != nullptr; - random_save_present.store(present); + present = getenv("GVISOR_SAVE_TEST") != nullptr; + save_present.store(present); } return present.value(); } @@ -52,12 +42,10 @@ std::atomic<int> save_disable; } // namespace -bool IsRunningWithSaveRestore() { - return CooperativeSavePresent() || RandomSavePresent(); -} +bool IsRunningWithSaveRestore() { return SavePresent(); } void MaybeSave() { - if (CooperativeSavePresent() && save_disable.load() == 0) { + if (SavePresent() && save_disable.load() == 0) { internal::DoCooperativeSave(); } } diff --git a/test/util/test_util.h b/test/util/test_util.h index 876ff58db..bcbb388ed 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -272,10 +272,15 @@ PosixErrorOr<std::vector<OpenFd>> GetOpenFDs(); // Returns the number of hard links to a path. 
PosixErrorOr<uint64_t> Links(const std::string& path); +inline uint64_t ns_elapsed(const struct timespec& begin, + const struct timespec& end) { + return (end.tv_sec - begin.tv_sec) * 1000000000 + + (end.tv_nsec - begin.tv_nsec); +} + inline uint64_t ms_elapsed(const struct timespec& begin, const struct timespec& end) { - return (end.tv_sec - begin.tv_sec) * 1000 + - (end.tv_nsec - begin.tv_nsec) / 1000000; + return ns_elapsed(begin, end) / 1000000; } namespace internal { diff --git a/tools/BUILD b/tools/BUILD index faf310676..3861ff2a5 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -9,3 +9,11 @@ bzl_library( "//:sandbox", ], ) + +bzl_library( + name = "deps_bzl", + srcs = ["deps.bzl"], + visibility = [ + "//:sandbox", + ], +) diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD index c2c1287a1..24e6f8a94 100644 --- a/tools/bazeldefs/BUILD +++ b/tools/bazeldefs/BUILD @@ -1,6 +1,9 @@ -load("//tools:defs.bzl", "bzl_library") +load("//tools:defs.bzl", "bzl_library", "go_proto_library") -package(licenses = ["notice"]) +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) bzl_library( name = "platforms_bzl", @@ -45,3 +48,9 @@ genrule( stamp = True, visibility = ["//:sandbox"], ) + +go_proto_library( + name = "worker_protocol_go_proto", + importpath = "gvisor.dev/bazel/worker_protocol_go_proto", + proto = "@bazel_tools//src/main/protobuf:worker_protocol_proto", +) diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl index bcd8cffe7..da027846b 100644 --- a/tools/bazeldefs/go.bzl +++ b/tools/bazeldefs/go.bzl @@ -8,8 +8,13 @@ load("//tools/bazeldefs:defs.bzl", "select_arch", "select_system") gazelle = _gazelle go_embed_data = _go_embed_data go_path = _go_path +bazel_worker_proto = "//tools/bazeldefs:worker_protocol_go_proto" def _go_proto_or_grpc_library(go_library_func, name, **kwargs): + if "importpath" in kwargs: + # If importpath is explicit, pass straight through. 
+ go_library_func(name = name, **kwargs) + return deps = [ dep.replace("_proto", "_go_proto") for dep in (kwargs.pop("deps", []) or []) @@ -132,7 +137,7 @@ def go_context(ctx, goos = None, goarch = None, std = False): runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs), goos = go_ctx.sdk.goos, goarch = go_ctx.sdk.goarch, - tags = go_ctx.tags, + gotags = go_ctx.tags, ) def select_goarch(): diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD index 1cea9e1c9..2b116fe0d 100644 --- a/tools/bigquery/BUILD +++ b/tools/bigquery/BUILD @@ -6,11 +6,13 @@ go_library( name = "bigquery", testonly = 1, srcs = ["bigquery.go"], + nogo = False, # FIXME(b/184974218): Analysis failing for cloud libraries. visibility = [ "//:sandbox", ], deps = [ "@com_google_cloud_go_bigquery//:go_default_library", "@org_golang_google_api//option:go_default_library", + "@org_golang_x_oauth2//:go_default_library", ], ) diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go index a4ca93ec2..935154acc 100644 --- a/tools/bigquery/bigquery.go +++ b/tools/bigquery/bigquery.go @@ -119,6 +119,14 @@ func NewBenchmark(name string, iters int) *Benchmark { } } +// NewBenchmarkWithMetric creates a new sending to BigQuery, initialized with a +// single iteration and single metric. +func NewBenchmarkWithMetric(name, metric, unit string, value float64) *Benchmark { + b := NewBenchmark(name, 1) + b.AddMetric(metric, unit, value) + return b +} + // NewSuite initializes a new Suite. 
func NewSuite(name string, official bool) *Suite { return &Suite{ diff --git a/tools/checklocks/checklocks.go b/tools/checklocks/checklocks.go index 4ec2918f6..1e877d394 100644 --- a/tools/checklocks/checklocks.go +++ b/tools/checklocks/checklocks.go @@ -563,9 +563,7 @@ func (pc *passContext) checkFunctionCall(call *ssa.Call, isExempted bool, lh *lo if !ok { return } - if fn.Object() == nil { - log.Warningf("fn w/ nil Object is: %+v", fn) return } @@ -579,7 +577,6 @@ func (pc *passContext) checkFunctionCall(call *ssa.Call, isExempted bool, lh *lo r := (*call.Value().Operands(nil)[fg.ParameterNumber+1]) guardObj := findField(r, fg.FieldNumber) if guardObj == nil { - log.Infof("guardObj nil but funcFact: %+v", funcFact) continue } var fieldFacts lockFieldFacts diff --git a/tools/defs.bzl b/tools/defs.bzl index d2c697c0d..27542a2f5 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -10,7 +10,7 @@ load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_ load("//tools/nogo:defs.bzl", "nogo_test") load("//tools/bazeldefs:defs.bzl", _arch_genrule = "arch_genrule", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _more_shards = "more_shards", _most_shards = "most_shards", _proto_library = "proto_library", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _version = "version") load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option") -load("//tools/bazeldefs:go.bzl", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = 
"go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_rule = "go_rule", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos") +load("//tools/bazeldefs:go.bzl", _bazel_worker_proto = "bazel_worker_proto", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_rule = "go_rule", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos") load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") load("//tools/bazeldefs:tags.bzl", "go_suffixes") @@ -47,6 +47,8 @@ go_path = _go_path select_goos = _select_goos select_goarch = _select_goarch go_embed_data = _go_embed_data +go_proto_library = _go_proto_library +bazel_worker_proto = _bazel_worker_proto # Packaging rules. pkg_deb = _pkg_deb diff --git a/tools/deps.bzl b/tools/deps.bzl new file mode 100644 index 000000000..91442617c --- /dev/null +++ b/tools/deps.bzl @@ -0,0 +1,119 @@ +"""Rules for dependency checking.""" + +# DepsInfo provides a list of dependencies found when building a target. +DepsInfo = provider( + "lists dependencies encountered while building", + fields = { + "nodes": "a dict from targets to a list of their dependencies", + }, +) + +def _deps_check_impl(target, ctx): + # Check the target's dependencies and add any of our own deps. + deps = [] + for dep in ctx.rule.attr.deps: + deps.append(dep) + nodes = {} + if len(deps) != 0: + nodes[target] = deps + + # Keep and propagate each dep's providers. 
+ for dep in ctx.rule.attr.deps: + nodes.update(dep[DepsInfo].nodes) + + return [DepsInfo(nodes = nodes)] + +_deps_check = aspect( + implementation = _deps_check_impl, + attr_aspects = ["deps"], +) + +def _is_allowed(target, allowlist, prefixes): + # Check for allowed prefixes. + for prefix in prefixes: + workspace, pfx = prefix.split("//", 1) + if len(workspace) > 0 and workspace[0] == "@": + workspace = workspace[1:] + if target.workspace_name == workspace and target.package.startswith(pfx): + return True + + # Check the allowlist. + for allowed in allowlist: + if target == allowed.label: + return True + + return False + +def _deps_test_impl(ctx): + nodes = {} + for target in ctx.attr.targets: + for (node_target, node_deps) in target[DepsInfo].nodes.items(): + # Ignore any disallowed targets. This generates more useful error + # messages. Consider the case where A dependes on B and B depends + # on C, and both B and C are disallowed. Avoid emitting an error + # that B depends on C, when the real issue is that A depends on B. + if not _is_allowed(node_target.label, ctx.attr.allowed, ctx.attr.allowed_prefixes) and node_target.label != target.label: + continue + bad_deps = [] + for dep in node_deps: + if not _is_allowed(dep.label, ctx.attr.allowed, ctx.attr.allowed_prefixes): + bad_deps.append(dep) + if len(bad_deps) > 0: + nodes[node_target] = bad_deps + + # If there aren't any violations, write a passing test. + if len(nodes) == 0: + ctx.actions.write( + output = ctx.outputs.executable, + content = "#!/bin/bash\n\nexit 0\n", + ) + return [] + + # If we're here, we've found at least one violation. + script_lines = [ + "#!/bin/bash", + "echo Invalid dependencies found. If you\\'re sure you want to add dependencies,", + "echo modify this target.", + "echo", + ] + + # List the violations. 
+ for target, deps in nodes.items(): + script_lines.append( + 'echo "{target} depends on:"'.format(target = target.label), + ) + for dep in deps: + script_lines.append('echo "\t{dep}"'.format(dep = dep.label)) + + # The test must fail. + script_lines.append("exit 1\n") + + ctx.actions.write( + output = ctx.outputs.executable, + content = "\n".join(script_lines), + ) + return [] + +# Checks that targets only depend on an allowlist of other targets. Targets can +# be specified directly, or prefixes can be used to allow entire packages or +# directory trees. +# +# This recursively checks the "deps" attribute of each target, dependencies +# expressed other ways are not checked. For example, protobuf targets pull in +# protobuf code, but aren't analyzed by deps_test. +deps_test = rule( + implementation = _deps_test_impl, + attrs = { + "targets": attr.label_list( + doc = "The targets to check the transitive dependencies of.", + aspects = [_deps_check], + ), + "allowed": attr.label_list( + doc = "The allowed dependency targets.", + ), + "allowed_prefixes": attr.string_list( + doc = "Any packages beginning with these prefixes are allowed.", + ), + }, + test = True, +) diff --git a/tools/github/BUILD b/tools/github/BUILD index 7d0a179f7..a345debf6 100644 --- a/tools/github/BUILD +++ b/tools/github/BUILD @@ -7,7 +7,6 @@ go_binary( srcs = ["main.go"], nogo = False, deps = [ - "//tools/github/nogo", "//tools/github/reviver", "@com_github_google_go_github_v32//github:go_default_library", "@org_golang_x_oauth2//:go_default_library", diff --git a/tools/github/main.go b/tools/github/main.go index 681003eef..dfb4c769d 100644 --- a/tools/github/main.go +++ b/tools/github/main.go @@ -22,12 +22,10 @@ import ( "io/ioutil" "log" "os" - "os/exec" "strings" "github.com/google/go-github/github" "golang.org/x/oauth2" - "gvisor.dev/gvisor/tools/github/nogo" "gvisor.dev/gvisor/tools/github/reviver" ) @@ -53,11 +51,10 @@ func (s *stringList) Set(value string) error { // Keep the options simple 
for now. Supports only a single path and repo. func init() { - flag.StringVar(&owner, "owner", "", "GitHub project org/owner (required, except nogo dry-run)") - flag.StringVar(&repo, "repo", "", "GitHub repo (required, except nogo dry-run)") + flag.StringVar(&owner, "owner", "", "GitHub project org/owner") + flag.StringVar(&repo, "repo", "", "GitHub repo") flag.StringVar(&tokenFile, "oauth-token-file", "", "file containing the GitHub token (or GITHUB_TOKEN is set)") - flag.Var(&paths, "path", "path(s) to scan (required for revive and nogo)") - flag.StringVar(&commit, "commit", "", "commit to associated (required for nogo, except dry-run)") + flag.Var(&paths, "path", "path(s) to scan (required for revive)") flag.BoolVar(&dryRun, "dry-run", false, "just print changes to be made") } @@ -96,12 +93,12 @@ func main() { // Check for mandatory parameters. command := args[0] - if len(owner) == 0 && (command != "nogo" || !dryRun) { + if len(owner) == 0 { fmt.Fprintln(flag.CommandLine.Output(), "missing --owner option.") flag.Usage() os.Exit(1) } - if len(repo) == 0 && (command != "nogo" || !dryRun) { + if len(repo) == 0 { fmt.Fprintln(flag.CommandLine.Output(), "missing --repo option.") flag.Usage() os.Exit(1) @@ -155,28 +152,6 @@ func main() { } os.Exit(1) } - case "nogo": - // Did we get a commit? Try to extract one. - if len(commit) == 0 && !dryRun { - cmd := exec.Command("git", "rev-parse", "HEAD") - revBytes, err := cmd.Output() - if err != nil { - fmt.Fprintf(flag.CommandLine.Output(), "missing --commit option, unable to infer: %v\n", err) - flag.Usage() - os.Exit(1) - } - commit = strings.TrimSpace(string(revBytes)) - } - // Scan all findings. - poster := nogo.NewFindingsPoster(client, owner, repo, commit, dryRun) - if err := poster.Walk(filteredPaths); err != nil { - fmt.Fprintln(os.Stderr, "Error finding nogo findings:", err) - os.Exit(1) - } - // Post to GitHub. 
- if err := poster.Post(); err != nil { - fmt.Fprintln(os.Stderr, "Error posting nogo findings:", err) - } default: // Not a known command. fmt.Fprintf(flag.CommandLine.Output(), "unknown command: %s\n", command) diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD deleted file mode 100644 index 4259fe94c..000000000 --- a/tools/github/nogo/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "nogo", - srcs = ["nogo.go"], - nogo = False, - visibility = [ - "//tools/github:__subpackages__", - ], - deps = [ - "//tools/nogo", - "@com_github_google_go_github_v32//github:go_default_library", - ], -) diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go deleted file mode 100644 index 894a0e7c3..000000000 --- a/tools/github/nogo/nogo.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package nogo provides nogo-related utilities. -package nogo - -import ( - "context" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - "github.com/google/go-github/github" - "gvisor.dev/gvisor/tools/nogo" -) - -// FindingsPoster is a simple wrapper around the GitHub api. 
-type FindingsPoster struct { - owner string - repo string - commit string - dryRun bool - startTime time.Time - - findings map[nogo.Finding]struct{} - client *github.Client -} - -// NewFindingsPoster returns a object that can post findings. -func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun bool) *FindingsPoster { - return &FindingsPoster{ - owner: owner, - repo: repo, - commit: commit, - dryRun: dryRun, - startTime: time.Now(), - findings: make(map[nogo.Finding]struct{}), - client: client, - } -} - -// Walk walks the given path tree for findings files. -func (p *FindingsPoster) Walk(paths []string) error { - for _, path := range paths { - if err := filepath.Walk(path, func(filename string, info os.FileInfo, err error) error { - if err != nil { - return err - } - // Skip any directories or files not ending in .findings. - if !strings.HasSuffix(filename, ".findings") || info.IsDir() { - return nil - } - findings, err := nogo.ExtractFindingsFromFile(filename) - if err != nil { - return err - } - // Add all findings to the list. We use a map to ensure - // that each finding is unique. - for _, finding := range findings { - p.findings[finding] = struct{}{} - } - return nil - }); err != nil { - return err - } - } - return nil -} - -// Post posts all results to the GitHub API as a check run. -func (p *FindingsPoster) Post() error { - // Just show results? - if p.dryRun { - for finding := range p.findings { - // Pretty print, so that this is useful for debugging. - fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Position.Filename, finding.Position.Line, finding.Message) - } - return nil - } - - // Construct the message. - title := "nogo" - count := len(p.findings) - status := "completed" - conclusion := "success" - if count > 0 { - conclusion = "failure" // Contains errors. 
- } - summary := fmt.Sprintf("%d findings.", count) - opts := github.CreateCheckRunOptions{ - Name: title, - HeadSHA: p.commit, - Status: &status, - Conclusion: &conclusion, - StartedAt: &github.Timestamp{p.startTime}, - CompletedAt: &github.Timestamp{time.Now()}, - Output: &github.CheckRunOutput{ - Title: &title, - Summary: &summary, - AnnotationsCount: &count, - }, - } - annotationLevel := "failure" // Always. - for finding := range p.findings { - title := string(finding.Category) - opts.Output.Annotations = append(opts.Output.Annotations, &github.CheckRunAnnotation{ - Path: &finding.Position.Filename, - StartLine: &finding.Position.Line, - EndLine: &finding.Position.Line, - Message: &finding.Message, - Title: &title, - AnnotationLevel: &annotationLevel, - }) - } - - // Post to GitHub. - _, _, err := p.client.Checks.CreateCheckRun(context.Background(), p.owner, p.repo, opts) - return err -} diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index f79defea7..85a1adf66 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -5,9 +5,7 @@ licenses(["notice"]) go_binary( name = "go_marshal", srcs = ["main.go"], - visibility = [ - "//:sandbox", - ], + visibility = ["//:sandbox"], deps = [ "//tools/go_marshal/gomarshal", ], @@ -16,6 +14,7 @@ go_binary( config_setting( name = "marshal_config_verbose", values = {"define": "gomarshal=verbose"}, + visibility = ["//:sandbox"], ) bzl_library( diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index eddba0c21..bbd4c9f48 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -140,3 +140,6 @@ options, depending on how go-marshal is being invoked: - Set `debug = True` on the `go_marshal` BUILD rule. - Pass `-debug` to the go-marshal tool invocation. + +If bazel complains about stdout output being too large, set a larger value +through `--experimental_ui_max_stdouterr_bytes`, or `-1` for unlimited output. 
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index e23901815..9f620cb76 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -57,7 +57,6 @@ go_marshal = rule( # marshal_deps are the dependencies requied by generated code. marshal_deps = [ "//pkg/gohacks", - "//pkg/safecopy", "//pkg/hostarch", "//pkg/marshal", ] diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 0e2d752cb..00961c90d 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -112,10 +112,8 @@ func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, g.imports.add("runtime") g.imports.add("unsafe") g.imports.add("gvisor.dev/gvisor/pkg/gohacks") - g.imports.add("gvisor.dev/gvisor/pkg/safecopy") g.imports.add("gvisor.dev/gvisor/pkg/hostarch") g.imports.add("gvisor.dev/gvisor/pkg/marshal") - return &g, nil } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go index 32afece2e..bd7741ae5 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -33,13 +33,13 @@ func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType } func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) { + g.recordUsedImport("gohacks") + g.recordUsedImport("hostarch") g.recordUsedImport("io") g.recordUsedImport("marshal") g.recordUsedImport("reflect") g.recordUsedImport("runtime") - g.recordUsedImport("safecopy") g.recordUsedImport("unsafe") - g.recordUsedImport("hostarch") lenExpr := g.arrayLenExpr(a) @@ -89,14 +89,14 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") g.emit("func (%s *%s) MarshalUnsafe(dst 
[]byte) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&%s[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go index 05f0e0db4..ba4b7324e 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -95,13 +95,13 @@ func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { // newtypes are always packed, so we can omit the various fallbacks required for // non-packed structs. 
func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { + g.recordUsedImport("gohacks") + g.recordUsedImport("hostarch") g.recordUsedImport("io") g.recordUsedImport("marshal") g.recordUsedImport("reflect") g.recordUsedImport("runtime") - g.recordUsedImport("safecopy") g.recordUsedImport("unsafe") - g.recordUsedImport("hostarch") g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") g.emit("//go:nosplit\n") @@ -141,14 +141,14 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) }) g.emit("}\n\n") @@ -260,11 +260,9 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emit("}\n") g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - g.emitNoEscapeSliceDataPointer("&src", "val") - - g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") - g.emitKeepAlive("src") - g.emit("return length, err\n") + g.emit("dst = dst[:size*count]\n") + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst)))\n") + g.emit("return size*count, nil\n") }) g.emit("}\n\n") @@ -279,11 +277,9 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emit("}\n") g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - 
g.emitNoEscapeSliceDataPointer("&dst", "val") - - g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") - g.emitKeepAlive("dst") - g.emit("return length, err\n") + g.emit("src = src[:(size*count)]\n") + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src)))\n") + g.emit("return size*count, nil\n") }) g.emit("}\n\n") } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index 72df1ab64..4c47218f1 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -270,18 +270,18 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("%s.MarshalBytes(dst)\n", g.r) } if thisPacked { - g.recordUsedImport("safecopy") + g.recordUsedImport("gohacks") g.recordUsedImport("unsafe") if cond, ok := g.areFieldsPackedExpression(); ok { g.emit("if %s {\n", cond) g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) }) g.emit("} else {\n") g.inIndent(fallback) g.emit("}\n") } else { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) } } else { fallback() @@ -297,25 +297,23 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("%s.UnmarshalBytes(src)\n", g.r) } if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") + g.recordUsedImport("gohacks") if cond, ok := g.areFieldsPackedExpression(); ok { g.emit("if %s {\n", cond) g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) }) g.emit("} else {\n") 
g.inIndent(fallback) g.emit("}\n") } else { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) } } else { fallback() } }) g.emit("}\n\n") - g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") g.emit("//go:nosplit\n") g.recordUsedImport("marshal") @@ -561,16 +559,15 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.recordUsedImport("reflect") g.recordUsedImport("runtime") g.recordUsedImport("unsafe") + g.recordUsedImport("gohacks") if _, ok := g.areFieldsPackedExpression(); ok { g.emit("if !src[0].Packed() {\n") g.inIndent(fallback) g.emit("}\n\n") } - g.emitNoEscapeSliceDataPointer("&src", "val") - - g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") - g.emitKeepAlive("src") - g.emit("return length, err\n") + g.emit("dst = dst[:size*count]\n") + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst)))\n") + g.emit("return size * count, nil\n") } else { fallback() } @@ -598,19 +595,19 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("return size * count, nil\n") } if thisPacked { + g.recordUsedImport("gohacks") g.recordUsedImport("reflect") g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") if _, ok := g.areFieldsPackedExpression(); ok { g.emit("if !dst[0].Packed() {\n") g.inIndent(fallback) g.emit("}\n\n") } - g.emitNoEscapeSliceDataPointer("&dst", "val") - g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") - g.emitKeepAlive("dst") - g.emit("return length, err\n") + g.emit("src = src[:(size*count)]\n") + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src)))\n") + + g.emit("return count*size, nil\n") } else { fallback() } diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index 5fc60d8d8..6c6f604b5 100644 --- a/tools/nogo/BUILD +++ 
b/tools/nogo/BUILD @@ -37,6 +37,7 @@ go_library( "//tools/checkescape", "//tools/checklocks", "//tools/checkunsafe", + "//tools/worker", "@co_honnef_go_tools//staticcheck:go_default_library", "@co_honnef_go_tools//stylecheck:go_default_library", "@org_golang_x_tools//go/analysis:go_default_library", diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go index 8b4bff3b6..2b3c03fec 100644 --- a/tools/nogo/analyzers.go +++ b/tools/nogo/analyzers.go @@ -83,11 +83,6 @@ var AllAnalyzers = []*analysis.Analyzer{ checklocks.Analyzer, } -// EscapeAnalyzers is a list of escape-related analyzers. -var EscapeAnalyzers = []*analysis.Analyzer{ - checkescape.EscapeAnalyzer, -} - func register(all []*analysis.Analyzer) { // Register all fact types. // @@ -129,5 +124,4 @@ func init() { // Register lists. register(AllAnalyzers) - register(EscapeAnalyzers) } diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD index e18483a18..666780dd3 100644 --- a/tools/nogo/check/BUILD +++ b/tools/nogo/check/BUILD @@ -7,5 +7,8 @@ go_binary( srcs = ["main.go"], nogo = False, visibility = ["//visibility:public"], - deps = ["//tools/nogo"], + deps = [ + "//tools/nogo", + "//tools/worker", + ], ) diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go index 69bdfe502..3a6c3fb08 100644 --- a/tools/nogo/check/main.go +++ b/tools/nogo/check/main.go @@ -24,6 +24,7 @@ import ( "os" "gvisor.dev/gvisor/tools/nogo" + "gvisor.dev/gvisor/tools/worker" ) var ( @@ -31,7 +32,6 @@ var ( stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)") findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)") factsOutput = flag.String("facts", "", "output file for facts (optional)") - escapesOutput = flag.String("escapes", "", "output file for escapes (optional)") ) func loadConfig(file string, config interface{}) interface{} { @@ -50,9 +50,10 @@ func loadConfig(file string, config interface{}) interface{} { } func main() { - // Parse all 
flags. - flag.Parse() + worker.Work(run) +} +func run([]string) int { var ( findings []nogo.Finding factData []byte @@ -66,25 +67,13 @@ func main() { // Run the configuration. if *stdlibFile != "" { - // Perform basic analysis. + // Perform stdlib analysis. c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig) findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers) - } else if *packageFile != "" { - // Perform basic analysis. + // Perform standard analysis. c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig) findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil) - - // Do we need to do escape analysis? - if *escapesOutput != "" { - escapes, _, err := nogo.CheckPackage(c, nogo.EscapeAnalyzers, nil) - if err != nil { - log.Fatalf("error performing escape analysis: %v", err) - } - if err := nogo.WriteFindingsToFile(escapes, *escapesOutput); err != nil { - log.Fatalf("error writing escapes to %q: %v", *escapesOutput, err) - } - } } else { log.Fatalf("please provide at least one of package or stdlib!") } @@ -103,7 +92,11 @@ func main() { // Write all findings. 
if *findingsOutput != "" { - if err := nogo.WriteFindingsToFile(findings, *findingsOutput); err != nil { + w, err := os.OpenFile(*findingsOutput, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + log.Fatalf("error opening output file %q: %v", *findingsOutput, err) + } + if err := nogo.WriteFindingsTo(w, findings, false /* json */); err != nil { log.Fatalf("error writing findings to %q: %v", *findingsOutput, err) } } else { @@ -111,4 +104,6 @@ func main() { fmt.Fprintf(os.Stdout, "%s\n", finding.String()) } } + + return 0 } diff --git a/tools/nogo/config.go b/tools/nogo/config.go index 2fea5b3e1..6436f9d34 100644 --- a/tools/nogo/config.go +++ b/tools/nogo/config.go @@ -73,17 +73,28 @@ type ItemConfig struct { } func compileRegexps(ss []string, rs *[]*regexp.Regexp) error { - *rs = make([]*regexp.Regexp, 0, len(ss)) - for _, s := range ss { + *rs = make([]*regexp.Regexp, len(ss)) + for i, s := range ss { r, err := regexp.Compile(s) if err != nil { return err } - *rs = append(*rs, r) + (*rs)[i] = r } return nil } +// RegexpCount is used by AnalyzerConfig.RegexpCount. +func (i *ItemConfig) RegexpCount() int64 { + if i == nil { + // See compile. + return 0 + } + // Return the number of regular expressions compiled for these items. + // This is how the cache size of the configuration is measured. + return int64(len(i.exclude) + len(i.suppress)) +} + func (i *ItemConfig) compile() error { if i == nil { // This may be nil if nothing is included in the @@ -100,9 +111,25 @@ func (i *ItemConfig) compile() error { return nil } +func merge(a, b []string) []string { + found := make(map[string]struct{}) + result := make([]string, 0, len(a)+len(b)) + for _, elem := range a { + found[elem] = struct{}{} + result = append(result, elem) + } + for _, elem := range b { + if _, ok := found[elem]; ok { + continue + } + result = append(result, elem) + } + return result +} + func (i *ItemConfig) merge(other *ItemConfig) { - i.Exclude = append(i.Exclude, other.Exclude...) 
- i.Suppress = append(i.Suppress, other.Suppress...) + i.Exclude = merge(i.Exclude, other.Exclude) + i.Suppress = merge(i.Suppress, other.Suppress) } func (i *ItemConfig) shouldReport(fullPos, msg string) bool { @@ -129,6 +156,15 @@ func (i *ItemConfig) shouldReport(fullPos, msg string) bool { // configurations depending on what Group the file belongs to. type AnalyzerConfig map[GroupName]*ItemConfig +// RegexpCount is used by Config.Size. +func (a AnalyzerConfig) RegexpCount() int64 { + count := int64(0) + for _, gc := range a { + count += gc.RegexpCount() + } + return count +} + func (a AnalyzerConfig) compile() error { for name, gc := range a { if err := gc.compile(); err != nil { @@ -179,22 +215,36 @@ type Config struct { Analyzers map[AnalyzerName]AnalyzerConfig `yaml:"analyzers"` } +// Size implements worker.Sizer.Size. +func (c *Config) Size() int64 { + count := c.Global.RegexpCount() + for _, config := range c.Analyzers { + count += config.RegexpCount() + } + // The size is measured as the number of regexps that are compiled + // here. We multiply by 1k to produce an estimate. + return 1024 * count +} + // Merge merges two configurations. func (c *Config) Merge(other *Config) { // Merge all groups. + // + // Select the other first, as the order provided in the second will + // provide precendence over the same group defined in the first one. + seenGroups := make(map[GroupName]struct{}) + newGroups := make([]Group, 0, len(c.Groups)+len(other.Groups)) for _, g := range other.Groups { - // Is there a matching group? If yes, we just delete - // it. This will preserve the order provided in the - // overriding file, even if it differs. 
- for i := 0; i < len(c.Groups); i++ { - if g.Name == c.Groups[i].Name { - copy(c.Groups[i:], c.Groups[i+1:]) - c.Groups = c.Groups[:len(c.Groups)-1] - break - } + newGroups = append(newGroups, g) + seenGroups[g.Name] = struct{}{} + } + for _, g := range c.Groups { + if _, ok := seenGroups[g.Name]; ok { + continue } - c.Groups = append(c.Groups, g) + newGroups = append(newGroups, g) } + c.Groups = newGroups // Merge global configurations. c.Global.merge(other.Global) diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 0c48a7a5a..ddf5816a6 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -29,6 +29,7 @@ NogoTargetInfo = provider( fields = { "goarch": "the build architecture (GOARCH)", "goos": "the build OS target (GOOS)", + "worker_debug": "transitive debugging", }, ) @@ -36,6 +37,7 @@ def _nogo_target_impl(ctx): return [NogoTargetInfo( goarch = ctx.attr.goarch, goos = ctx.attr.goos, + worker_debug = ctx.attr.worker_debug, )] nogo_target = go_rule( @@ -50,6 +52,10 @@ nogo_target = go_rule( doc = "the Go OS target (propagated to other rules).", mandatory = True, ), + "worker_debug": attr.bool( + doc = "whether worker debugging should be enabled.", + default = False, + ), }, ) @@ -61,7 +67,7 @@ def _nogo_objdump_tool_impl(ctx): # we need the tool to handle this case by creating a temporary file. 
# # [1] https://github.com/golang/go/issues/41051 - nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] + nogo_target_info = ctx.attr._target[NogoTargetInfo] go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()]) dumper = ctx.actions.declare_file(ctx.label.name) @@ -94,7 +100,7 @@ nogo_objdump_tool = go_rule( rule, implementation = _nogo_objdump_tool_impl, attrs = { - "_nogo_target": attr.label( + "_target": attr.label( default = "//tools/nogo:target", cfg = "target", ), @@ -112,7 +118,7 @@ NogoStdlibInfo = provider( def _nogo_stdlib_impl(ctx): # Build the standard library facts. - nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] + nogo_target_info = ctx.attr._target[NogoTargetInfo] go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(ctx.label.name + ".facts") raw_findings = ctx.actions.declare_file(ctx.label.name + ".raw_findings") @@ -120,22 +126,33 @@ def _nogo_stdlib_impl(ctx): Srcs = [f.path for f in go_ctx.stdlib_srcs], GOOS = go_ctx.goos, GOARCH = go_ctx.goarch, - Tags = go_ctx.tags, + Tags = go_ctx.gotags, ) config_file = ctx.actions.declare_file(ctx.label.name + ".cfg") ctx.actions.write(config_file, config.to_json()) - ctx.actions.run( - inputs = [config_file] + go_ctx.stdlib_srcs, - outputs = [facts, raw_findings], - tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), - executable = ctx.files._nogo_check[0], - mnemonic = "NogoStandardLibraryAnalysis", - progress_message = "Analyzing Go Standard Library", - arguments = go_ctx.nogo_args + [ - "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path, + args_file = ctx.actions.declare_file(ctx.label.name + "_args_file") + ctx.actions.write( + output = args_file, + content = "\n".join(go_ctx.nogo_args + [ + "-objdump_tool=%s" % ctx.files._objdump_tool[0].path, "-stdlib=%s" % 
config_file.path, "-findings=%s" % raw_findings.path, "-facts=%s" % facts.path, + ]), + ) + ctx.actions.run( + inputs = [config_file] + go_ctx.stdlib_srcs + [args_file], + outputs = [facts, raw_findings], + tools = depset(go_ctx.runfiles.to_list() + ctx.files._objdump_tool), + executable = ctx.files._check[0], + mnemonic = "GoStandardLibraryAnalysis", + # Note that this does not support work execution currently. There is an + # issue with stdout pollution that is not yet resolved, so this is kept + # as a separate menomic. + progress_message = "Analyzing Go Standard Library", + arguments = [ + "--worker_debug=%s" % nogo_target_info.worker_debug, + "@%s" % args_file.path, ], ) @@ -149,15 +166,15 @@ nogo_stdlib = go_rule( rule, implementation = _nogo_stdlib_impl, attrs = { - "_nogo_check": attr.label( + "_check": attr.label( default = "//tools/nogo/check:check", cfg = "host", ), - "_nogo_objdump_tool": attr.label( + "_objdump_tool": attr.label( default = "//tools/nogo:objdump_tool", cfg = "host", ), - "_nogo_target": attr.label( + "_target": attr.label( default = "//tools/nogo:target", cfg = "target", ), @@ -174,7 +191,6 @@ NogoInfo = provider( fields = { "facts": "serialized package facts", "raw_findings": "raw package findings (if relevant)", - "escapes": "escape-only findings (if relevant)", "importpath": "package import path", "binaries": "package binary files", "srcs": "srcs (for go_test support)", @@ -277,18 +293,17 @@ def _nogo_aspect_impl(target, ctx): inputs.append(stdlib_facts) # The nogo tool operates on a configuration serialized in JSON format. 
- nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] + nogo_target_info = ctx.attr._target[NogoTargetInfo] go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(target.label.name + ".facts") raw_findings = ctx.actions.declare_file(target.label.name + ".raw_findings") - escapes = ctx.actions.declare_file(target.label.name + ".escapes") config = struct( ImportPath = importpath, GoFiles = [src.path for src in srcs if src.path.endswith(".go")], NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], GOOS = go_ctx.goos, GOARCH = go_ctx.goarch, - Tags = go_ctx.tags, + Tags = go_ctx.gotags, FactMap = fact_map, ImportMap = import_map, StdlibFacts = stdlib_facts.path, @@ -296,20 +311,28 @@ def _nogo_aspect_impl(target, ctx): config_file = ctx.actions.declare_file(target.label.name + ".cfg") ctx.actions.write(config_file, config.to_json()) inputs.append(config_file) - ctx.actions.run( - inputs = inputs, - outputs = [facts, raw_findings, escapes], - tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), - executable = ctx.files._nogo_check[0], - mnemonic = "NogoAnalysis", - progress_message = "Analyzing %s" % target.label, - arguments = go_ctx.nogo_args + [ + args_file = ctx.actions.declare_file(ctx.label.name + "_args_file") + ctx.actions.write( + output = args_file, + content = "\n".join(go_ctx.nogo_args + [ "-binary=%s" % target_objfile.path, - "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path, + "-objdump_tool=%s" % ctx.files._objdump_tool[0].path, "-package=%s" % config_file.path, "-findings=%s" % raw_findings.path, "-facts=%s" % facts.path, - "-escapes=%s" % escapes.path, + ]), + ) + ctx.actions.run( + inputs = inputs + [args_file], + outputs = [facts, raw_findings], + tools = depset(go_ctx.runfiles.to_list() + ctx.files._objdump_tool), + executable = ctx.files._check[0], + mnemonic = "GoStaticAnalysis", + progress_message = "Analyzing %s" % 
target.label, + execution_requirements = {"supports-workers": "1"}, + arguments = [ + "--worker_debug=%s" % nogo_target_info.worker_debug, + "@%s" % args_file.path, ], ) @@ -322,15 +345,16 @@ def _nogo_aspect_impl(target, ctx): all_raw_findings = [stdlib_info.raw_findings] + depset(all_raw_findings).to_list() + [raw_findings] # Return the package facts as output. - return [NogoInfo( - facts = facts, - raw_findings = all_raw_findings, - escapes = escapes, - importpath = importpath, - binaries = binaries, - srcs = srcs, - deps = deps, - )] + return [ + NogoInfo( + facts = facts, + raw_findings = all_raw_findings, + importpath = importpath, + binaries = binaries, + srcs = srcs, + deps = deps, + ), + ] nogo_aspect = go_rule( aspect, @@ -341,47 +365,60 @@ nogo_aspect = go_rule( "embed", ], attrs = { - "_nogo_check": attr.label( + "_check": attr.label( default = "//tools/nogo/check:check", cfg = "host", ), - "_nogo_stdlib": attr.label( - default = "//tools/nogo:stdlib", - cfg = "host", - ), - "_nogo_objdump_tool": attr.label( + "_objdump_tool": attr.label( default = "//tools/nogo:objdump_tool", cfg = "host", ), - "_nogo_target": attr.label( + "_target": attr.label( default = "//tools/nogo:target", cfg = "target", ), + # The name of this attribute must not be _stdlib, since that + # appears to be reserved for some internal bazel use. + "_nogo_stdlib": attr.label( + default = "//tools/nogo:stdlib", + cfg = "host", + ), }, ) def _nogo_test_impl(ctx): """Check nogo findings.""" + nogo_target_info = ctx.attr._target[NogoTargetInfo] # Ensure there's a single dependency. if len(ctx.attr.deps) != 1: fail("nogo_test requires exactly one dep.") raw_findings = ctx.attr.deps[0][NogoInfo].raw_findings - escapes = ctx.attr.deps[0][NogoInfo].escapes # Build a step that applies the configuration. 
config_srcs = ctx.attr.config[NogoConfigInfo].srcs findings = ctx.actions.declare_file(ctx.label.name + ".findings") + args_file = ctx.actions.declare_file(ctx.label.name + "_args_file") + ctx.actions.write( + output = args_file, + content = "\n".join( + ["-input=%s" % f.path for f in raw_findings] + + ["-config=%s" % f.path for f in config_srcs] + + ["-output=%s" % findings.path], + ), + ) ctx.actions.run( - inputs = raw_findings + ctx.files.srcs + config_srcs, + inputs = raw_findings + ctx.files.srcs + config_srcs + [args_file], outputs = [findings], tools = depset(ctx.files._filter), executable = ctx.files._filter[0], mnemonic = "GoStaticAnalysis", progress_message = "Generating %s" % ctx.label, - arguments = ["-input=%s" % f.path for f in raw_findings] + - ["-config=%s" % f.path for f in config_srcs] + - ["-output=%s" % findings.path], + execution_requirements = {"supports-workers": "1"}, + arguments = [ + "--worker_debug=%s" % nogo_target_info.worker_debug, + "@%s" % args_file.path, + ], ) # Build a runner that checks the filtered facts. @@ -392,7 +429,7 @@ def _nogo_test_impl(ctx): runner = ctx.actions.declare_file(ctx.label.name) runner_content = [ "#!/bin/bash", - "exec %s -input=%s" % (ctx.files._filter[0].short_path, findings.short_path), + "exec %s -check -input=%s" % (ctx.files._filter[0].short_path, findings.short_path), "", ] ctx.actions.write(runner, "\n".join(runner_content), is_executable = True) @@ -409,8 +446,6 @@ def _nogo_test_impl(ctx): # pays attention to the mnemoic above, so this must be # what is expected by the tooling. nogo_findings = depset([findings]), - # Expose all escape analysis findings (see above). - nogo_escapes = depset([escapes]), )] nogo_test = rule( @@ -428,7 +463,26 @@ nogo_test = rule( allow_files = True, doc = "Relevant src files. 
This is ignored except to make the nogo_test directly affected by the files.", ), + "_target": attr.label( + default = "//tools/nogo:target", + cfg = "target", + ), "_filter": attr.label(default = "//tools/nogo/filter:filter"), }, test = True, ) + +def _nogo_aspect_tricorder_impl(target, ctx): + if ctx.rule.kind != "nogo_test" or OutputGroupInfo not in target: + return [] + if not hasattr(target[OutputGroupInfo], "nogo_findings"): + return [] + return [ + OutputGroupInfo(tricorder = target[OutputGroupInfo].nogo_findings), + ] + +# Trivial aspect that forwards the findings from a nogo_test rule to +# go/tricorder, which reads from the `tricorder` output group. +nogo_aspect_tricorder = aspect( + implementation = _nogo_aspect_tricorder_impl, +) diff --git a/tools/nogo/filter/BUILD b/tools/nogo/filter/BUILD index e56a783e2..e3049521e 100644 --- a/tools/nogo/filter/BUILD +++ b/tools/nogo/filter/BUILD @@ -9,6 +9,7 @@ go_binary( visibility = ["//visibility:public"], deps = [ "//tools/nogo", + "//tools/worker", "@in_gopkg_yaml_v2//:go_default_library", ], ) diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go index 8be38ca6d..d50336b9b 100644 --- a/tools/nogo/filter/main.go +++ b/tools/nogo/filter/main.go @@ -26,6 +26,7 @@ import ( yaml "gopkg.in/yaml.v2" "gvisor.dev/gvisor/tools/nogo" + "gvisor.dev/gvisor/tools/worker" ) type stringList []string @@ -44,34 +45,44 @@ var ( configFiles stringList outputFile string showConfig bool + check bool ) func init() { - flag.Var(&inputFiles, "input", "findings input files") - flag.StringVar(&outputFile, "output", "", "findings output file") + flag.Var(&inputFiles, "input", "findings input files (gob format)") + flag.StringVar(&outputFile, "output", "", "findings output file (json format)") flag.Var(&configFiles, "config", "findings configuration files") flag.BoolVar(&showConfig, "show-config", false, "dump configuration only") + flag.BoolVar(&check, "check", false, "assume input is in json format") } func main() { - 
flag.Parse() + worker.Work(run) +} - // Load all available findings. - var findings []nogo.Finding - for _, filename := range inputFiles { - inputFindings, err := nogo.ExtractFindingsFromFile(filename) +var ( + cachedFindings = worker.NewCache("findings") // With nogo.FindingSet. + cachedFiltered = worker.NewCache("filtered") // With nogo.FindingSet. + cachedConfigs = worker.NewCache("configs") // With nogo.Config. + cachedFullConfigs = worker.NewCache("compiled") // With nogo.Config. +) + +func loadFindings(filename string) nogo.FindingSet { + return cachedFindings.Lookup([]string{filename}, func() worker.Sizer { + r, err := os.Open(filename) + if err != nil { + log.Fatalf("unable to open input %q: %v", filename, err) + } + inputFindings, err := nogo.ExtractFindingsFrom(r, check /* json */) if err != nil { log.Fatalf("unable to extract findings from %s: %v", filename, err) } - findings = append(findings, inputFindings...) - } + return inputFindings + }).(nogo.FindingSet) +} - // Open and merge all configuations. 
- config := &nogo.Config{ - Global: make(nogo.AnalyzerConfig), - Analyzers: make(map[nogo.AnalyzerName]nogo.AnalyzerConfig), - } - for _, filename := range configFiles { +func loadConfig(filename string) *nogo.Config { + return cachedConfigs.Lookup([]string{filename}, func() worker.Sizer { content, err := ioutil.ReadFile(filename) if err != nil { log.Fatalf("unable to read %s: %v", filename, err) @@ -82,53 +93,98 @@ func main() { if err := dec.Decode(&newConfig); err != nil { log.Fatalf("unable to decode %s: %v", filename, err) } - config.Merge(&newConfig) if showConfig { content, err := yaml.Marshal(&newConfig) if err != nil { log.Fatalf("error marshalling config: %v", err) } - mergedBytes, err := yaml.Marshal(config) - if err != nil { - log.Fatalf("error marshalling config: %v", err) - } fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(content)) - fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes)) } - } - if err := config.Compile(); err != nil { - log.Fatalf("error compiling config: %v", err) - } + return &newConfig + }).(*nogo.Config) +} + +func loadConfigs(filenames []string) *nogo.Config { + return cachedFullConfigs.Lookup(filenames, func() worker.Sizer { + config := &nogo.Config{ + Global: make(nogo.AnalyzerConfig), + Analyzers: make(map[nogo.AnalyzerName]nogo.AnalyzerConfig), + } + for _, filename := range configFiles { + config.Merge(loadConfig(filename)) + if showConfig { + mergedBytes, err := yaml.Marshal(config) + if err != nil { + log.Fatalf("error marshalling config: %v", err) + } + fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes)) + } + } + if err := config.Compile(); err != nil { + log.Fatalf("error compiling config: %v", err) + } + return config + }).(*nogo.Config) +} + +func run([]string) int { + // Open and merge all configuations. + config := loadConfigs(configFiles) if showConfig { - os.Exit(0) + return 0 } - // Filter the findings (and aggregate by group). 
- filteredFindings := make([]nogo.Finding, 0, len(findings)) - for _, finding := range findings { - if ok := config.ShouldReport(finding); ok { - filteredFindings = append(filteredFindings, finding) - } + // Load and filer available findings. + var filteredFindings []nogo.Finding + for _, filename := range inputFiles { + // Note that this applies a caching strategy to the filtered + // findings, because *this is by far the most expensive part of + // evaluation*. The set of findings is large and applying the + // configuration is complex. Therefore, we segment this cache + // on each individual raw findings input file and the + // configuration files. Note that this cache is keyed on all + // the configuration files and each individual raw findings, so + // is guaranteed to be safe. This allows us to reuse the same + // filter result many times over, because e.g. all standard + // library findings will be available to all packages. + filteredFindings = append(filteredFindings, + cachedFiltered.Lookup(append(configFiles, filename), func() worker.Sizer { + inputFindings := loadFindings(filename) + filteredFindings := make(nogo.FindingSet, 0, len(inputFindings)) + for _, finding := range inputFindings { + if ok := config.ShouldReport(finding); ok { + filteredFindings = append(filteredFindings, finding) + } + } + return filteredFindings + }).(nogo.FindingSet)...) } // Write the output (if required). // // If the outputFile is specified, then we exit here. Otherwise, // we continue to write to stdout and treat like a test. + // + // Note that the output of the filter is always json, which is + // human readable and the format that is consumed by tricorder. 
if outputFile != "" { - if err := nogo.WriteFindingsToFile(filteredFindings, outputFile); err != nil { + w, err := os.OpenFile(outputFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + log.Fatalf("unable to open output file %q: %v", outputFile, err) + } + if err := nogo.WriteFindingsTo(w, filteredFindings, true /* json */); err != nil { log.Fatalf("unable to write findings: %v", err) } - return + return 0 } // Treat the run as a test. if len(filteredFindings) == 0 { fmt.Fprintf(os.Stdout, "PASS\n") - os.Exit(0) + return 0 } for _, finding := range filteredFindings { fmt.Fprintf(os.Stdout, "%s\n", finding.String()) } - os.Exit(1) + return 1 } diff --git a/tools/nogo/findings.go b/tools/nogo/findings.go index 5bd850269..329a7062e 100644 --- a/tools/nogo/findings.go +++ b/tools/nogo/findings.go @@ -15,10 +15,14 @@ package nogo import ( + "encoding/gob" "encoding/json" "fmt" "go/token" - "io/ioutil" + "io" + "os" + "reflect" + "sort" ) // Finding is a single finding. @@ -28,36 +32,96 @@ type Finding struct { Message string } +// findingSize is the size of the finding struct itself. +var findingSize = int64(reflect.TypeOf(Finding{}).Size()) + +// Size implements worker.Sizer.Size. +func (f *Finding) Size() int64 { + return int64(len(f.Category)) + int64(len(f.Message)) + findingSize +} + // String implements fmt.Stringer.String. func (f *Finding) String() string { return fmt.Sprintf("%s: %s: %s", f.Category, f.Position.String(), f.Message) } -// WriteFindingsToFile writes findings to a file. -func WriteFindingsToFile(findings []Finding, filename string) error { - content, err := WriteFindingsToBytes(findings) - if err != nil { - return err +// FindingSet is a collection of findings. +type FindingSet []Finding + +// Size implmements worker.Sizer.Size. 
+func (fs FindingSet) Size() int64 { + size := int64(0) + for _, finding := range fs { + size += finding.Size() } - return ioutil.WriteFile(filename, content, 0644) + return size } -// WriteFindingsToBytes serializes findings as bytes. -func WriteFindingsToBytes(findings []Finding) ([]byte, error) { - return json.Marshal(findings) +// Sort sorts all findings. +func (fs FindingSet) Sort() { + sort.Slice(fs, func(i, j int) bool { + switch { + case fs[i].Position.Filename < fs[j].Position.Filename: + return true + case fs[i].Position.Filename > fs[j].Position.Filename: + return false + case fs[i].Position.Line < fs[j].Position.Line: + return true + case fs[i].Position.Line > fs[j].Position.Line: + return false + case fs[i].Position.Column < fs[j].Position.Column: + return true + case fs[i].Position.Column > fs[j].Position.Column: + return false + case fs[i].Category < fs[j].Category: + return true + case fs[i].Category > fs[j].Category: + return false + case fs[i].Message < fs[j].Message: + return true + case fs[i].Message > fs[j].Message: + return false + default: + return false + } + }) +} + +// WriteFindingsTo serializes findings. +func WriteFindingsTo(w io.Writer, findings FindingSet, asJSON bool) error { + // N.B. Sort all the findings in order to maximize cacheability. + findings.Sort() + if asJSON { + enc := json.NewEncoder(w) + return enc.Encode(findings) + } + enc := gob.NewEncoder(w) + return enc.Encode(findings) } // ExtractFindingsFromFile loads findings from a file. -func ExtractFindingsFromFile(filename string) ([]Finding, error) { - content, err := ioutil.ReadFile(filename) +func ExtractFindingsFromFile(filename string, asJSON bool) (FindingSet, error) { + r, err := os.Open(filename) if err != nil { return nil, err } - return ExtractFindingsFromBytes(content) + defer r.Close() + return ExtractFindingsFrom(r, asJSON) } // ExtractFindingsFromBytes loads findings from bytes. 
-func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) { - err = json.Unmarshal(content, &findings) +func ExtractFindingsFrom(r io.Reader, asJSON bool) (findings FindingSet, err error) { + if asJSON { + dec := json.NewDecoder(r) + err = dec.Decode(&findings) + } else { + dec := gob.NewDecoder(r) + err = dec.Decode(&findings) + } return findings, err } + +func init() { + gob.Register((*Finding)(nil)) + gob.Register((*FindingSet)(nil)) +} diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index 779d4d6d8..acee7c8bc 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -19,7 +19,8 @@ package nogo import ( - "encoding/json" + "bytes" + "encoding/gob" "errors" "fmt" "go/ast" @@ -34,6 +35,7 @@ import ( "path" "path/filepath" "reflect" + "sort" "strings" "golang.org/x/tools/go/analysis" @@ -42,6 +44,7 @@ import ( // Special case: flags live here and change overall behavior. "gvisor.dev/gvisor/tools/checkescape" + "gvisor.dev/gvisor/tools/worker" ) // StdlibConfig is serialized as the configuration. @@ -75,39 +78,100 @@ type loader func(string) ([]byte, error) // saver is a fact-saver function. type saver func([]byte) error -// factLoader returns a function that loads facts. -// -// This resolves all standard library facts and imported package facts up -// front. The returned loader function will never return an error, only -// empty facts. -// -// This is done because all stdlib data is stored together, and we don't want -// to load this data many times over. 
-func (c *PackageConfig) factLoader() (loader, error) { - allFacts := make(map[string][]byte) - if c.StdlibFacts != "" { - data, err := ioutil.ReadFile(c.StdlibFacts) - if err != nil { - return nil, fmt.Errorf("error loading stdlib facts from %q: %w", c.StdlibFacts, err) - } - var stdlibFacts map[string][]byte - if err := json.Unmarshal(data, &stdlibFacts); err != nil { - return nil, fmt.Errorf("error loading stdlib facts: %w", err) - } - for pkg, data := range stdlibFacts { - allFacts[pkg] = data +// stdlibFact is used for serialiation. +type stdlibFact struct { + Package string + Facts []byte +} + +// stdlibFacts is a set of standard library facts. +type stdlibFacts map[string][]byte + +// Size implements worker.Sizer.Size. +func (sf stdlibFacts) Size() int64 { + size := int64(0) + for filename, data := range sf { + size += int64(len(filename)) + size += int64(len(data)) + } + return size +} + +// EncodeTo serializes stdlibFacts. +func (sf stdlibFacts) EncodeTo(w io.Writer) error { + stdlibFactsSorted := make([]stdlibFact, 0, len(sf)) + for pkg, facts := range sf { + stdlibFactsSorted = append(stdlibFactsSorted, stdlibFact{ + Package: pkg, + Facts: facts, + }) + } + sort.Slice(stdlibFactsSorted, func(i, j int) bool { + return stdlibFactsSorted[i].Package < stdlibFactsSorted[j].Package + }) + enc := gob.NewEncoder(w) + if err := enc.Encode(stdlibFactsSorted); err != nil { + return err + } + return nil +} + +// DecodeFrom deserializes stdlibFacts. +func (sf stdlibFacts) DecodeFrom(r io.Reader) error { + var stdlibFactsSorted []stdlibFact + dec := gob.NewDecoder(r) + if err := dec.Decode(&stdlibFactsSorted); err != nil { + return err + } + for _, stdlibFact := range stdlibFactsSorted { + sf[stdlibFact.Package] = stdlibFact.Facts + } + return nil +} + +var ( + // cachedFacts caches by file (just byte data). + cachedFacts = worker.NewCache("facts") + + // stdlibCachedFacts caches the standard library (stdlibFacts). 
+ stdlibCachedFacts = worker.NewCache("stdlib") +) + +// factLoader loads facts. +func (c *PackageConfig) factLoader(path string) (data []byte, err error) { + filename, ok := c.FactMap[path] + if ok { + cb := cachedFacts.Lookup([]string{filename}, func() worker.Sizer { + data, readErr := ioutil.ReadFile(filename) + if readErr != nil { + err = fmt.Errorf("error loading %q: %w", filename, readErr) + return nil + } + return worker.CacheBytes(data) + }) + if cb != nil { + return []byte(cb.(worker.CacheBytes)), err } + return nil, err } - for pkg, file := range c.FactMap { - data, err := ioutil.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("error loading %q: %w", file, err) + cb := stdlibCachedFacts.Lookup([]string{c.StdlibFacts}, func() worker.Sizer { + r, openErr := os.Open(c.StdlibFacts) + if openErr != nil { + err = fmt.Errorf("error loading stdlib facts from %q: %w", c.StdlibFacts, openErr) + return nil + } + defer r.Close() + sf := make(stdlibFacts) + if readErr := sf.DecodeFrom(r); readErr != nil { + err = fmt.Errorf("error loading stdlib facts: %w", readErr) + return nil } - allFacts[pkg] = data + return sf + }) + if cb != nil { + return (cb.(stdlibFacts))[path], err } - return func(path string) ([]byte, error) { - return allFacts[path], nil - }, nil + return nil, err } // shouldInclude indicates whether the file should be included. @@ -191,7 +255,7 @@ var ErrSkip = errors.New("skipped") // // Note that not all parts of the source are expected to build. We skip obvious // test files, and cmd files, which should not be dependencies. 
-func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindings []Finding, facts []byte, err error) { +func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindings FindingSet, facts []byte, err error) { if len(config.Srcs) == 0 { return nil, nil, nil } @@ -261,16 +325,16 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi } // Closure to check a single package. - stdlibFacts := make(map[string][]byte) - stdlibErrs := make(map[string]error) + localStdlibFacts := make(stdlibFacts) + localStdlibErrs := make(map[string]error) var checkOne func(pkg string) error // Recursive. checkOne = func(pkg string) error { // Is this already done? - if _, ok := stdlibFacts[pkg]; ok { + if _, ok := localStdlibFacts[pkg]; ok { return nil } // Did this fail previously? - if _, ok := stdlibErrs[pkg]; ok { + if _, ok := localStdlibErrs[pkg]; ok { return nil } @@ -286,7 +350,7 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi // If there's no binary for this package, it is likely // not built with the distribution. That's fine, we can // just skip analysis. - stdlibErrs[pkg] = err + localStdlibErrs[pkg] = err return nil } @@ -303,10 +367,10 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi if err != nil { // If we can't analyze a package from the standard library, // then we skip it. It will simply not have any findings. - stdlibErrs[pkg] = err + localStdlibErrs[pkg] = err return nil } - stdlibFacts[pkg] = factData + localStdlibFacts[pkg] = factData allFindings = append(allFindings, findings...) return nil } @@ -323,23 +387,23 @@ func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindi } // Sanity check. - if len(stdlibFacts) == 0 { + if len(localStdlibFacts) == 0 { return nil, nil, fmt.Errorf("no stdlib facts found: misconfiguration?") } // Write out all findings. 
- factData, err := json.Marshal(stdlibFacts) - if err != nil { - return nil, nil, fmt.Errorf("error saving stdlib facts: %w", err) + buf := bytes.NewBuffer(nil) + if err := localStdlibFacts.EncodeTo(buf); err != nil { + return nil, nil, fmt.Errorf("error serialized stdlib facts: %v", err) } // Write out all errors. - for pkg, err := range stdlibErrs { + for pkg, err := range localStdlibErrs { log.Printf("WARNING: error while processing %v: %v", pkg, err) } // Return all findings. - return allFindings, factData, nil + return allFindings, buf.Bytes(), nil } // CheckPackage runs all given analyzers. @@ -392,11 +456,7 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC } // Load all package facts. - loader, err := config.factLoader() - if err != nil { - return nil, nil, fmt.Errorf("error loading facts: %w", err) - } - facts, err := facts.Decode(types, loader) + facts, err := facts.Decode(types, config.factLoader) if err != nil { return nil, nil, fmt.Errorf("error decoding facts: %w", err) } @@ -471,3 +531,7 @@ func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importC // Return all findings. return findings, facts.Encode(), nil } + +func init() { + gob.Register((*stdlibFact)(nil)) +} diff --git a/tools/worker/BUILD b/tools/worker/BUILD new file mode 100644 index 000000000..dc03ce11e --- /dev/null +++ b/tools/worker/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "bazel_worker_proto", "go_library") + +package(licenses = ["notice"]) + +# For Google-tooling. 
+# @unused +glaze_ignore = [ + "worker.go", +] + +go_library( + name = "worker", + srcs = ["worker.go"], + visibility = ["//tools:__subpackages__"], + deps = [ + bazel_worker_proto, + "@org_golang_google_protobuf//encoding/protowire:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/tools/worker/worker.go b/tools/worker/worker.go new file mode 100644 index 000000000..669a5f203 --- /dev/null +++ b/tools/worker/worker.go @@ -0,0 +1,325 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package worker provides an implementation of the bazel worker protocol. +// +// Tools may be written as a normal command line utility, except the passed +// run function may be invoked multiple times. +package worker + +import ( + "bufio" + "bytes" + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "path/filepath" + "sort" + "strings" + "time" + + _ "net/http/pprof" // For profiling. 
+ + "golang.org/x/sys/unix" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + wpb "gvisor.dev/bazel/worker_protocol_go_proto" +) + +var ( + persistentWorker = flag.Bool("persistent_worker", false, "enable persistent worker.") + workerDebug = flag.Bool("worker_debug", false, "debug persistent workers.") + maximumCacheUsage = flag.Int64("maximum_cache_usage", 1024*1024*1024, "maximum cache size.") +) + +var ( + // inputFiles is the last set of input files. + // + // This is used for cache invalidation. The key is the *absolute* path + // name, and the value is the digest in the current run. + inputFiles = make(map[string]string) + + // activeCaches is the set of active caches. + activeCaches = make(map[*Cache]struct{}) + + // totalCacheUsage is the total usage of all caches. + totalCacheUsage int64 +) + +// mustAbs returns the absolute path of a filename or dies. +func mustAbs(filename string) string { + abs, err := filepath.Abs(filename) + if err != nil { + log.Fatalf("error getting absolute path: %v", err) + } + return abs +} + +// updateInputFiles creates an entry in inputFiles. +func updateInputFile(filename, digest string) { + inputFiles[mustAbs(filename)] = digest +} + +// Sizer returns a size. +type Sizer interface { + Size() int64 +} + +// CacheBytes is an example of a Sizer. +type CacheBytes []byte + +// Size implements Sizer.Size. +func (cb CacheBytes) Size() int64 { + return int64(len(cb)) +} + +// Cache is a worker cache. +// +// They can be created via NewCache. +type Cache struct { + name string + entries map[string]Sizer + size int64 + hits int64 + misses int64 +} + +// NewCache returns a new cache. +func NewCache(name string) *Cache { + return &Cache{ + name: name, + } +} + +// Lookup looks up an entry in the cache. +// +// It is a function of the given files. 
+func (c *Cache) Lookup(filenames []string, generate func() Sizer) Sizer { + digests := make([]string, 0, len(filenames)) + for _, filename := range filenames { + digest, ok := inputFiles[mustAbs(filename)] + if !ok { + // This is not a valid input. We may not be running as + // persistent worker in this cache. If that's the case, + // then the file's contents will not change across the + // run, and we just use the filename itself. + digest = filename + } + digests = append(digests, digest) + } + + // Attempt the lookup. + sort.Slice(digests, func(i, j int) bool { + return digests[i] < digests[j] + }) + cacheKey := strings.Join(digests, "+") + if c.entries == nil { + c.entries = make(map[string]Sizer) + activeCaches[c] = struct{}{} + } + entry, ok := c.entries[cacheKey] + if ok { + c.hits++ + return entry + } + + // Generate a new entry. + entry = generate() + c.misses++ + c.entries[cacheKey] = entry + if entry != nil { + sz := entry.Size() + c.size += sz + totalCacheUsage += sz + } + + // Check the capacity of all caches. If it greater than the maximum, we + // flush everything but still return this entry. + if totalCacheUsage > *maximumCacheUsage { + for entry, _ := range activeCaches { + // Drop all entries. + entry.size = 0 + entry.entries = nil + } + totalCacheUsage = 0 // Reset. + } + + return entry +} + +// allCacheStats returns stats for all caches. +func allCacheStats() string { + var sb strings.Builder + for entry, _ := range activeCaches { + ratio := float64(entry.hits) / float64(entry.hits+entry.misses) + fmt.Fprintf(&sb, + "% 10s: count: % 5d size: % 10d hits: % 7d misses: % 7d ratio: %2.2f\n", + entry.name, len(entry.entries), entry.size, entry.hits, entry.misses, ratio) + } + if len(activeCaches) > 0 { + fmt.Fprintf(&sb, "total: % 10d\n", totalCacheUsage) + } + return sb.String() +} + +// LookupDigest returns a digest for the given file. 
+func LookupDigest(filename string) (string, bool) { + digest, ok := inputFiles[filename] + return digest, ok +} + +// Work invokes the main function. +func Work(run func([]string) int) { + flag.CommandLine.Parse(os.Args[1:]) + if !*persistentWorker { + // Handle the argument file. + args := flag.CommandLine.Args() + if len(args) == 1 && len(args[0]) > 1 && args[0][0] == '@' { + content, err := ioutil.ReadFile(args[0][1:]) + if err != nil { + log.Fatalf("unable to parse args file: %v", err) + } + // Pull arguments from the file. + args = strings.Split(string(content), "\n") + flag.CommandLine.Parse(args) + args = flag.CommandLine.Args() + } + os.Exit(run(args)) + } + + var listenHeader string // Emitted always. + if *workerDebug { + // Bind a server for profiling. + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + log.Fatalf("unable to bind a server: %v", err) + } + // Construct the header for stats output, below. + listenHeader = fmt.Sprintf("Listening @ http://localhost:%d\n", listener.Addr().(*net.TCPAddr).Port) + go http.Serve(listener, nil) + } + + // Move stdout. This is done to prevent anything else from accidentally + // printing to stdout, which must contain only the valid WorkerResponse + // serialized protos. + newOutput, err := unix.Dup(1) + if err != nil { + log.Fatalf("unable to move stdout: %v", err) + } + // Stderr may be closed or may be a copy of stdout. We make sure that + // we have an output that is in a completely separate range. + for newOutput <= 2 { + newOutput, err = unix.Dup(newOutput) + if err != nil { + log.Fatalf("unable to move stdout: %v", err) + } + } + + // Best-effort: collect logs. 
+ rPipe, wPipe, err := os.Pipe() + if err != nil { + log.Fatalf("unable to create pipe: %v", err) + } + if err := unix.Dup2(int(wPipe.Fd()), 1); err != nil { + log.Fatalf("error duping over stdout: %v", err) + } + if err := unix.Dup2(int(wPipe.Fd()), 2); err != nil { + log.Fatalf("error duping over stderr: %v", err) + } + wPipe.Close() + defer rPipe.Close() + + // Read requests from stdin. + input := bufio.NewReader(os.NewFile(0, "input")) + output := bufio.NewWriter(os.NewFile(uintptr(newOutput), "output")) + for { + szBuf, err := input.Peek(4) + if err != nil { + log.Fatalf("unabel to read header: %v", err) + } + + // Parse the size, and discard bits. + sz, szBytes := protowire.ConsumeVarint(szBuf) + if szBytes < 0 { + szBytes = 0 + } + if _, err := input.Discard(szBytes); err != nil { + log.Fatalf("error discarding size: %v", err) + } + + // Read a full message. + msg := make([]byte, int(sz)) + if _, err := io.ReadFull(input, msg); err != nil { + log.Fatalf("error reading worker request: %v", err) + } + var wreq wpb.WorkRequest + if err := proto.Unmarshal(msg, &wreq); err != nil { + log.Fatalf("error unmarshaling worker request: %v", err) + } + + // Flush relevant caches. + inputFiles = make(map[string]string) + for _, input := range wreq.GetInputs() { + updateInputFile(input.GetPath(), string(input.GetDigest())) + } + + // Prepare logging. + outputBuffer := bytes.NewBuffer(nil) + outputBuffer.WriteString(listenHeader) + log.SetOutput(outputBuffer) + + // Parse all arguments. + flag.CommandLine.Parse(wreq.GetArguments()) + var exitCode int + exitChan := make(chan int) + go func() { exitChan <- run(flag.CommandLine.Args()) }() + for running := true; running; { + select { + case exitCode = <-exitChan: + running = false + default: + } + // N.B. rPipe is given a read deadline of 1ms. We expect + // this to turn a copy error after 1ms, and we just keep + // flushing this buffer while the task is running. 
+ rPipe.SetReadDeadline(time.Now().Add(time.Millisecond)) + outputBuffer.ReadFrom(rPipe) + } + + if *workerDebug { + // Attach all cache stats. + outputBuffer.WriteString(allCacheStats()) + } + + // Send the response. + var wresp wpb.WorkResponse + wresp.ExitCode = int32(exitCode) + wresp.Output = string(outputBuffer.Bytes()) + rmsg, err := proto.Marshal(&wresp) + if err != nil { + log.Fatalf("error marshaling response: %v", err) + } + if _, err := output.Write(append(protowire.AppendVarint(nil, uint64(len(rmsg))), rmsg...)); err != nil { + log.Fatalf("error sending worker response: %v", err) + } + if err := output.Flush(); err != nil { + log.Fatalf("error flushing output: %v", err) + } + } +} diff --git a/website/BUILD b/website/BUILD index b5b3f6df6..6f52e9208 100644 --- a/website/BUILD +++ b/website/BUILD @@ -14,7 +14,7 @@ docker_image( tags = [ "local", "manual", - "nosandbox", + "no-sandbox", ], ) @@ -69,7 +69,7 @@ genrule( tags = [ "local", "manual", - "nosandbox", + "no-sandbox", ], ) |