diff options
263 files changed, 11350 insertions, 4914 deletions
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1e1677ab5..e28e46352 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,3 +21,8 @@ jobs: restore-keys: | ${{ runner.os }}-bazel- - run: make + - run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..." + - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ nogo" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REPOSITORY: ${{ github.repository }} diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml index 2b399a3f2..c53185620 100644 --- a/.github/workflows/issue_reviver.yml +++ b/.github/workflows/issue_reviver.yml @@ -9,7 +9,7 @@ jobs: steps: - uses: actions/checkout@v2 if: github.repository == 'google/gvisor' - - run: make run TARGETS="//tools/issue_reviver" + - run: make run TARGETS="//tools/github" ARGS="revive" if: github.repository == 'google/gvisor' env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -379,3 +379,8 @@ configure: ## Configures a single runtime. Requires sudo. Typically called from test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided. @$(call submake,test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)") .PHONY: test-runtime + +nogo: ## Surfaces all nogo findings. + @$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...") + @$(call submake,run TARGETS="//tools/github" ARGS="-path=$(BUILD_ROOT) -dry-run nogo") +.PHONY: nogo diff --git a/g3doc/user_guide/networking.md b/g3doc/user_guide/networking.md index 4aa394c91..95f675633 100644 --- a/g3doc/user_guide/networking.md +++ b/g3doc/user_guide/networking.md @@ -2,9 +2,9 @@ [TOC] -gVisor implements its own network stack called [netstack][netstack]. All aspects -of the network stack are handled inside the Sentry — including TCP connection -state, control messages, and packet assembly — keeping it isolated from the host +gVisor implements its own network stack called netstack. All aspects of the +network stack are handled inside the Sentry — including TCP connection state, +control messages, and packet assembly — keeping it isolated from the host network stack. Data link layer packets are written directly to the virtual device inside the network namespace setup by Docker or Kubernetes. @@ -82,4 +82,3 @@ Offload (GSO) to run with a kernel that is newer than 3.17. Add the } ``` -[netstack]: https://github.com/google/netstack diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index cdcaa8c73..4a26e28de 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -38,6 +38,7 @@ go_library( "ipc.go", "limits.go", "linux.go", + "membarrier.go", "mm.go", "netdevice.go", "netfilter.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index dc9ac7e7c..7df02dd6d 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -121,9 +121,27 @@ const ( // Constants from uapi/linux/fsverity.h. const ( - FS_IOC_ENABLE_VERITY = 1082156677 + FS_IOC_ENABLE_VERITY = 1082156677 + FS_IOC_MEASURE_VERITY = 3221513862 ) +// DigestMetadata is a helper struct for VerityDigest. +// +// +marshal +type DigestMetadata struct { + DigestAlgorithm uint16 + DigestSize uint16 +} + +// SizeOfDigestMetadata is the size of struct DigestMetadata. +const SizeOfDigestMetadata = 4 + +// VerityDigest is struct from uapi/linux/fsverity.h. +type VerityDigest struct { + Metadata DigestMetadata + Digest []byte +} + // IOC outputs the result of _IOC macro in asm-generic/ioctl.h. func IOC(dir, typ, nr, size uint32) uint32 { return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT diff --git a/pkg/abi/linux/membarrier.go b/pkg/abi/linux/membarrier.go new file mode 100644 index 000000000..0f3c03e63 --- /dev/null +++ b/pkg/abi/linux/membarrier.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// membarrier(2) commands, from include/uapi/linux/membarrier.h +const ( + MEMBARRIER_CMD_QUERY = 0 + MEMBARRIER_CMD_GLOBAL = (1 << 0) + MEMBARRIER_CMD_GLOBAL_EXPEDITED = (1 << 1) + MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED = (1 << 2) + MEMBARRIER_CMD_PRIVATE_EXPEDITED = (1 << 3) + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED = (1 << 4) + MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 5) + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 6) +) diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 1c5b34711..b521144d9 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -265,6 +265,18 @@ type KernelXTEntryMatch struct { Data []byte } +// XTGetRevision corresponds to xt_get_revision in +// include/uapi/linux/netfilter/x_tables.h +// +// +marshal +type XTGetRevision struct { + Name ExtensionName + Revision uint8 +} + +// SizeOfXTGetRevision is the size of an XTGetRevision. +const SizeOfXTGetRevision = 30 + // XTEntryTarget holds a target for a rule. For example, it can specify that // packets matching the rule should DROP, ACCEPT, or use an extension target. // iptables-extension(8) has a list of possible targets. @@ -285,6 +297,13 @@ type XTEntryTarget struct { // SizeOfXTEntryTarget is the size of an XTEntryTarget. const SizeOfXTEntryTarget = 32 +// KernelXTEntryTarget is identical to XTEntryTarget, but contains a +// variable-length Data field. +type KernelXTEntryTarget struct { + XTEntryTarget + Data []byte +} + // XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE, // RETURN, or jump. It corresponds to struct xt_standard_target in // include/uapi/linux/netfilter/x_tables.h. @@ -510,6 +529,8 @@ type IPTReplace struct { const SizeOfIPTReplace = 96 // ExtensionName holds the name of a netfilter extension. +// +// +marshal type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index a137940b6..6d31eb5e3 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -321,3 +321,16 @@ const ( // Enable all flags. IP6T_INV_MASK = 0x7F ) + +// NFNATRange corresponds to struct nf_nat_range in +// include/uapi/linux/netfilter/nf_nat.h. +type NFNATRange struct { + Flags uint32 + MinAddr Inet6Addr + MaxAddr Inet6Addr + MinProto uint16 // Network byte order. + MaxProto uint16 // Network byte order. +} + +// SizeOfNFNATRange is the size of NFNATRange. +const SizeOfNFNATRange = 40 diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go index b07cafe12..5be3f10f9 100644 --- a/pkg/abi/linux/seccomp.go +++ b/pkg/abi/linux/seccomp.go @@ -83,3 +83,22 @@ type SockFprog struct { pad [6]byte Filter *BPFInstruction } + +// SeccompData is equivalent to struct seccomp_data, which contains the data +// passed to seccomp-bpf filters. +// +// +marshal +type SeccompData struct { + // Nr is the system call number. + Nr int32 + + // Arch is an AUDIT_ARCH_* value indicating the system call convention. + Arch uint32 + + // InstructionPointer is the value of the instruction pointer at the time + // of the system call. + InstructionPointer uint64 + + // Args contains the first 6 system call arguments. + Args [6]uint64 +} diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go index 85fad9956..468c6a387 100644 --- a/pkg/abi/linux/signalfd.go +++ b/pkg/abi/linux/signalfd.go @@ -23,6 +23,8 @@ const ( ) // SignalfdSiginfo is the siginfo encoding for signalfds. +// +// +marshal type SignalfdSiginfo struct { Signo uint32 Errno int32 @@ -41,5 +43,5 @@ type SignalfdSiginfo struct { STime uint64 Addr uint64 AddrLSB uint16 - _ [48]uint8 + _ [48]uint8 `marshal:"unaligned"` } diff --git a/pkg/marshal/BUILD b/pkg/marshal/BUILD index 4aec98218..aac0161fa 100644 --- a/pkg/marshal/BUILD +++ b/pkg/marshal/BUILD @@ -11,7 +11,5 @@ go_library( visibility = [ "//:sandbox", ], - deps = [ - "//pkg/usermem", - ], + deps = ["//pkg/usermem"], ) diff --git a/pkg/marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD index 06741e6d1..d77a11c79 100644 --- a/pkg/marshal/primitive/BUILD +++ b/pkg/marshal/primitive/BUILD @@ -12,6 +12,7 @@ go_library( "//:sandbox", ], deps = [ + "//pkg/context", "//pkg/marshal", "//pkg/usermem", ], diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index dfdae5d60..4b342de6b 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -19,6 +19,7 @@ package primitive import ( "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/usermem" ) @@ -126,6 +127,46 @@ var _ marshal.Marshallable = (*ByteSlice)(nil) // Below, we define some convenience functions for marshalling primitive types // using the newtypes above, without requiring superfluous casts. +// 8-bit integers + +// CopyInt8In is a convenient wrapper for copying in an int8 from the task's +// memory. +func CopyInt8In(cc marshal.CopyContext, addr usermem.Addr, dst *int8) (int, error) { + var buf Int8 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = int8(buf) + return n, nil +} + +// CopyInt8Out is a convenient wrapper for copying out an int8 to the task's +// memory. +func CopyInt8Out(cc marshal.CopyContext, addr usermem.Addr, src int8) (int, error) { + srcP := Int8(src) + return srcP.CopyOut(cc, addr) +} + +// CopyUint8In is a convenient wrapper for copying in a uint8 from the task's +// memory. +func CopyUint8In(cc marshal.CopyContext, addr usermem.Addr, dst *uint8) (int, error) { + var buf Uint8 + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = uint8(buf) + return n, nil +} + +// CopyUint8Out is a convenient wrapper for copying out a uint8 to the task's +// memory. +func CopyUint8Out(cc marshal.CopyContext, addr usermem.Addr, src uint8) (int, error) { + srcP := Uint8(src) + return srcP.CopyOut(cc, addr) +} + // 16-bit integers // CopyInt16In is a convenient wrapper for copying in an int16 from the task's @@ -245,3 +286,64 @@ func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int, srcP := Uint64(src) return srcP.CopyOut(cc, addr) } + +// CopyByteSliceIn is a convenient wrapper for copying in a []byte from the +// task's memory. +func CopyByteSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst *[]byte) (int, error) { + var buf ByteSlice + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = []byte(buf) + return n, nil +} + +// CopyByteSliceOut is a convenient wrapper for copying out a []byte to the +// task's memory. +func CopyByteSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []byte) (int, error) { + srcP := ByteSlice(src) + return srcP.CopyOut(cc, addr) +} + +// CopyStringIn is a convenient wrapper for copying in a string from the +// task's memory. +func CopyStringIn(cc marshal.CopyContext, addr usermem.Addr, dst *string) (int, error) { + var buf ByteSlice + n, err := buf.CopyIn(cc, addr) + if err != nil { + return n, err + } + *dst = string(buf) + return n, nil +} + +// CopyStringOut is a convenient wrapper for copying out a string to the task's +// memory. +func CopyStringOut(cc marshal.CopyContext, addr usermem.Addr, src string) (int, error) { + srcP := ByteSlice(src) + return srcP.CopyOut(cc, addr) +} + +// IOCopyContext wraps an object implementing usermem.IO to implement +// marshal.CopyContext. +type IOCopyContext struct { + Ctx context.Context + IO usermem.IO + Opts usermem.IOOpts +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (i *IOCopyContext) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (i *IOCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return i.IO.CopyOut(i.Ctx, addr, b, i.Opts) +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (i *IOCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return i.IO.CopyIn(i.Ctx, addr, b, i.Opts) +} diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 4b4f9bd52..ccef83259 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -41,7 +41,7 @@ type Layout struct { blockSize int64 // digestSize is the size of a generated hash. digestSize int64 - // levelOffset contains the offset of the begnning of each level in + // levelOffset contains the offset of the beginning of each level in // bytes. The number of levels in the tree is the length of the slice. // The leaf nodes (level 0) contain hashes of blocks of the input data. // Each level N contains hashes of the blocks in level N-1. The highest diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index bdef7762c..e828894b0 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -49,7 +49,7 @@ go_test( library = ":seccomp", deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/bpf", + "//pkg/usermem", ], ) diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 23f30678d..e1444d18b 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -28,17 +28,10 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/usermem" ) -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - // newVictim makes a victim binary. func newVictim() (string, error) { f, err := ioutil.TempFile("", "victim") @@ -58,9 +51,14 @@ func newVictim() (string, error) { return path, nil } -// asInput converts a seccompData to a bpf.Input. -func (d *seccompData) asInput() bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +// dataAsInput converts a linux.SeccompData to a bpf.Input. +func dataAsInput(d *linux.SeccompData) bpf.Input { + buf := make([]byte, d.SizeBytes()) + d.MarshalUnsafe(buf) + return bpf.InputBytes{ + Data: buf, + Order: usermem.ByteOrder, + } } func TestBasic(t *testing.T) { @@ -69,7 +67,7 @@ func TestBasic(t *testing.T) { desc string // data is the input data. - data seccompData + data linux.SeccompData // want is the expected return value of the BPF program. want linux.BPFAction @@ -95,12 +93,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "syscall allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "syscall disallowed", - data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -131,22 +129,22 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "allowed (1a)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "allowed (1b)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "syscall 1 matched 2nd rule", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "no match", - data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_KILL_THREAD, }, }, @@ -168,42 +166,42 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "allowed (1)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "allowed (3)", - data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 3, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "allowed (5)", - data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 5, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "disallowed (0)", - data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "disallowed (2)", - data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "disallowed (4)", - data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 4, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "disallowed (6)", - data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 6, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "disallowed (100)", - data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 100, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -223,7 +221,7 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arch (123)", - data: seccompData{nr: 1, arch: 123}, + data: linux.SeccompData{Nr: 1, Arch: 123}, want: linux.SECCOMP_RET_KILL_THREAD, }, }, @@ -243,7 +241,7 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "action trap", - data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -268,12 +266,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "disallowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xe}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -300,12 +298,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "match first rule", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "match 2nd rule", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xe}}, want: linux.SECCOMP_RET_ALLOW, }, }, @@ -331,28 +329,28 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "argument allowed (all match)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, }, want: linux.SECCOMP_RET_ALLOW, }, { desc: "argument disallowed (one mismatch)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, }, want: linux.SECCOMP_RET_TRAP, }, { desc: "argument disallowed (multiple mismatch)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, }, want: linux.SECCOMP_RET_TRAP, }, @@ -379,28 +377,28 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, }, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (one equal)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1}, }, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (all equal)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32}, }, want: linux.SECCOMP_RET_TRAP, }, @@ -429,27 +427,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "high 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits equal", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000003}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -474,27 +472,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (first arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (first arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -522,27 +520,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "high 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits equal", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -567,32 +565,32 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed (both greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg allowed (first arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (first arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg allowed (second arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (second arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (both arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -620,27 +618,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "high 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits equal", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, }, @@ -665,32 +663,32 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (first arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (first arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (both arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -718,27 +716,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "high 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits equal", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, }, @@ -764,32 +762,32 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg allowed (first arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (first arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg allowed (second arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (second arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (both arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -816,51 +814,51 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed (low order mandatory bit)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000000 00000001 - args: [6]uint64{0x1}, + Args: [6]uint64{0x1}, }, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg allowed (low order optional bit)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000000 00000101 - args: [6]uint64{0x5}, + Args: [6]uint64{0x5}, }, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (lowest order bit not set)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000000 00000010 - args: [6]uint64{0x2}, + Args: [6]uint64{0x2}, }, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second lowest order bit set)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000000 00000011 - args: [6]uint64{0x3}, + Args: [6]uint64{0x3}, }, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (8th bit set)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000001 00000000 - args: [6]uint64{0x100}, + Args: [6]uint64{0x100}, }, want: linux.SECCOMP_RET_TRAP, }, @@ -885,12 +883,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x7aabbccdd}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "disallowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x711223344}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -906,7 +904,7 @@ func TestBasic(t *testing.T) { t.Fatalf("bpf.Compile() got error: %v", err) } for _, spec := range test.specs { - got, err := bpf.Exec(p, spec.data.asInput()) + got, err := bpf.Exec(p, dataAsInput(&spec.data)) if err != nil { t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err) } @@ -947,8 +945,8 @@ func TestRandom(t *testing.T) { t.Fatalf("bpf.Compile() got error: %v", err) } for i := uint32(0); i < 200; i++ { - data := seccompData{nr: i, arch: LINUX_AUDIT_ARCH} - got, err := bpf.Exec(p, data.asInput()) + data := linux.SeccompData{Nr: int32(i), Arch: LINUX_AUDIT_ARCH} + got, err := bpf.Exec(p, dataAsInput(&data)) if err != nil { t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i) continue diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 99e2b3389..4af4d6e84 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -22,6 +22,7 @@ go_library( "signal_info.go", "signal_stack.go", "stack.go", + "stack_unsafe.go", "syscalls_amd64.go", "syscalls_arm64.go", ], diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index 0f433ee79..fd73751e7 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -154,6 +154,7 @@ func (s State) Proto() *rpb.Registers { Sp: s.Regs.Sp, Pc: s.Regs.Pc, Pstate: s.Regs.Pstate, + Tls: s.Regs.TPIDR_EL0, } return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}} } @@ -232,6 +233,7 @@ func (s *State) RegisterMap() (map[string]uintptr, error) { "Sp": uintptr(s.Regs.Sp), "Pc": uintptr(s.Regs.Pc), "Pstate": uintptr(s.Regs.Pstate), + "Tls": uintptr(s.Regs.TPIDR_EL0), }, nil } diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto index 60c027aab..2727ba08a 100644 --- a/pkg/sentry/arch/registers.proto +++ b/pkg/sentry/arch/registers.proto @@ -83,6 +83,7 @@ message ARM64Registers { uint64 sp = 32; uint64 pc = 33; uint64 pstate = 34; + uint64 tls = 35; } message Registers { oneof arch { diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 6fb756f0e..72e07a988 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -17,17 +17,19 @@ package arch import ( - "encoding/binary" "math" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" ) // SignalContext64 is equivalent to struct sigcontext, the type passed as the // second argument to signal handlers set by signal(2). +// +// +marshal type SignalContext64 struct { R8 uint64 R9 uint64 @@ -68,6 +70,8 @@ const ( ) // UContext64 is equivalent to ucontext_t on 64-bit x86. +// +// +marshal type UContext64 struct { Flags uint64 Link uint64 @@ -172,12 +176,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // "... the value (%rsp+8) is always a multiple of 16 (...) when // control is transferred to the function entry point." - AMD64 ABI - ucSize := binary.Size(uc) - if ucSize < 0 { - // This can only happen if we've screwed up the definition of - // UContext64. - panic("can't get size of UContext64") - } + ucSize := uc.SizeBytes() // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128. frameSize := int(st.Arch.Width()) + ucSize + 128 frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8 @@ -195,18 +194,18 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt info.FixSignalCodeForUser() // Set up the stack frame. - infoAddr, err := st.Push(info) - if err != nil { + if _, err := info.CopyOut(st, StackBottomMagic); err != nil { return err } - ucAddr, err := st.Push(uc) - if err != nil { + infoAddr := st.Bottom + if _, err := uc.CopyOut(st, StackBottomMagic); err != nil { return err } + ucAddr := st.Bottom if act.HasRestorer() { // Push the restorer return address. // Note that this doesn't need to be popped. - if _, err := st.Push(usermem.Addr(act.Restorer)); err != nil { + if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil { return err } } else { @@ -240,11 +239,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { // Copy out the stack frame. var uc UContext64 - if _, err := st.Pop(&uc); err != nil { + if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } var info SignalInfo - if _, err := st.Pop(&info); err != nil { + if _, err := info.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 642c79dda..7fde5d34e 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build arm64 + package arch import ( - "encoding/binary" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +26,8 @@ import ( // SignalContext64 is equivalent to struct sigcontext, the type passed as the // second argument to signal handlers set by signal(2). +// +// +marshal type SignalContext64 struct { FaultAddr uint64 Regs [31]uint64 @@ -36,6 +39,7 @@ type SignalContext64 struct { Reserved [3568]uint8 } +// +marshal type aarch64Ctx struct { Magic uint32 Size uint32 @@ -43,6 +47,8 @@ type aarch64Ctx struct { // FpsimdContext is equivalent to struct fpsimd_context on arm64 // (arch/arm64/include/uapi/asm/sigcontext.h). +// +// +marshal type FpsimdContext struct { Head aarch64Ctx Fpsr uint32 @@ -51,13 +57,15 @@ type FpsimdContext struct { } // UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h). +// +// +marshal type UContext64 struct { Flags uint64 Link uint64 Stack SignalStack Sigset linux.SignalSet // glibc uses a 1024-bit sigset_t - _pad [(1024 - 64) / 8]byte + _pad [120]byte // (1024 - 64) / 8 = 120 // sigcontext must be aligned to 16-byte _pad2 [8]byte // last for future expansion @@ -94,11 +102,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt }, Sigset: sigset, } - - ucSize := binary.Size(uc) - if ucSize < 0 { - panic("can't get size of UContext64") - } + ucSize := uc.SizeBytes() // frameSize = ucSize + sizeof(siginfo). // sizeof(siginfo) == 128. @@ -119,14 +123,14 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt info.FixSignalCodeForUser() // Set up the stack frame. - infoAddr, err := st.Push(info) - if err != nil { + if _, err := info.CopyOut(st, StackBottomMagic); err != nil { return err } - ucAddr, err := st.Push(uc) - if err != nil { + infoAddr := st.Bottom + if _, err := uc.CopyOut(st, StackBottomMagic); err != nil { return err } + ucAddr := st.Bottom // Set up registers. c.Regs.Sp = uint64(st.Bottom) @@ -147,11 +151,11 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { // Copy out the stack frame. var uc UContext64 - if _, err := st.Pop(&uc); err != nil { + if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } var info SignalInfo - if _, err := st.Pop(&info); err != nil { + if _, err := info.CopyIn(st, StackBottomMagic); err != nil { return 0, SignalStack{}, err } diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go index 1108fa0bd..5f06c751d 100644 --- a/pkg/sentry/arch/stack.go +++ b/pkg/sentry/arch/stack.go @@ -15,14 +15,16 @@ package arch import ( - "encoding/binary" - "fmt" - "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" ) -// Stack is a simple wrapper around a usermem.IO and an address. +// Stack is a simple wrapper around a usermem.IO and an address. Stack +// implements marshal.CopyContext, and marshallable values can be pushed or +// popped from the stack through the marshal.Marshallable interface. +// +// Stack is not thread-safe. type Stack struct { // Our arch info. // We use this for automatic Native conversion of usermem.Addrs during @@ -34,105 +36,60 @@ type Stack struct { // Our current stack bottom. Bottom usermem.Addr -} -// Push pushes the given values on to the stack. -// -// (This method supports Addrs and treats them as native types.) -func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) { - for _, v := range vals { - - // We convert some types to well-known serializable quanities. - var norm interface{} - - // For array types, we will automatically add an appropriate - // terminal value. This is done simply to make the interface - // easier to use. - var term interface{} - - switch v.(type) { - case string: - norm = []byte(v.(string)) - term = byte(0) - case []int8, []uint8: - norm = v - term = byte(0) - case []int16, []uint16: - norm = v - term = uint16(0) - case []int32, []uint32: - norm = v - term = uint32(0) - case []int64, []uint64: - norm = v - term = uint64(0) - case []usermem.Addr: - // Special case: simply push recursively. - _, err := s.Push(s.Arch.Native(uintptr(0))) - if err != nil { - return 0, err - } - varr := v.([]usermem.Addr) - for i := len(varr) - 1; i >= 0; i-- { - _, err := s.Push(varr[i]) - if err != nil { - return 0, err - } - } - continue - case usermem.Addr: - norm = s.Arch.Native(uintptr(v.(usermem.Addr))) - default: - norm = v - } + // Scratch buffer used for marshalling to avoid having to repeatedly + // allocate scratch memory. + scratchBuf []byte +} - if term != nil { - _, err := s.Push(term) - if err != nil { - return 0, err - } - } +// scratchBufLen is the default length of Stack.scratchBuf. The +// largest structs the stack regularly serializes are arch.SignalInfo +// and arch.UContext64. We'll set the default size as the larger of +// the two, arch.UContext64. +var scratchBufLen = (*UContext64)(nil).SizeBytes() - c := binary.Size(norm) - if c < 0 { - return 0, fmt.Errorf("bad binary.Size for %T", v) - } - n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{}) - if err != nil || c != n { - return 0, err - } +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (s *Stack) CopyScratchBuffer(size int) []byte { + if len(s.scratchBuf) < size { + s.scratchBuf = make([]byte, size) + } + return s.scratchBuf[:size] +} +// StackBottomMagic is the special address callers must past to all stack +// marshalling operations to cause the src/dst address to be computed based on +// the current end of the stack. +const StackBottomMagic = ^usermem.Addr(0) // usermem.Addr(-1) + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. CopyOutBytes +// computes an appropriate address based on the current end of the +// stack. Callers use the sentinel address StackBottomMagic to marshal methods +// to indicate this. +func (s *Stack) CopyOutBytes(sentinel usermem.Addr, b []byte) (int, error) { + if sentinel != StackBottomMagic { + panic("Attempted to copy out to stack with absolute address") + } + c := len(b) + n, err := s.IO.CopyOut(context.Background(), s.Bottom-usermem.Addr(c), b, usermem.IOOpts{}) + if err == nil && n == c { s.Bottom -= usermem.Addr(n) } - - return s.Bottom, nil + return n, err } -// Pop pops the given values off the stack. -// -// (This method supports Addrs and treats them as native types.) -func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) { - for _, v := range vals { - - vaddr, isVaddr := v.(*usermem.Addr) - - var n int - var err error - if isVaddr { - value := s.Arch.Native(uintptr(0)) - n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{}) - *vaddr = usermem.Addr(s.Arch.Value(value)) - } else { - n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{}) - } - if err != nil { - return 0, err - } - +// CopyInBytes implements marshal.CopyContext.CopyInBytes. CopyInBytes computes +// an appropriate address based on the current end of the stack. Callers must +// use the sentinel address StackBottomMagic to marshal methods to indicate +// this. +func (s *Stack) CopyInBytes(sentinel usermem.Addr, b []byte) (int, error) { + if sentinel != StackBottomMagic { + panic("Attempted to copy in from stack with absolute address") + } + n, err := s.IO.CopyIn(context.Background(), s.Bottom, b, usermem.IOOpts{}) + if err == nil { s.Bottom += usermem.Addr(n) } - - return s.Bottom, nil + return n, err } // Align aligns the stack to the given offset. @@ -142,6 +99,22 @@ func (s *Stack) Align(offset int) { } } +// PushNullTerminatedByteSlice writes bs to the stack, followed by an extra null +// byte at the end. On error, the contents of the stack and the bottom cursor +// are undefined. +func (s *Stack) PushNullTerminatedByteSlice(bs []byte) (int, error) { + // Note: Stack grows up, so write the terminal null byte first. + nNull, err := primitive.CopyUint8Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + n, err := primitive.CopyByteSliceOut(s, StackBottomMagic, bs) + if err != nil { + return 0, err + } + return n + nNull, nil +} + // StackLayout describes the location of the arguments and environment on the // stack. type StackLayout struct { @@ -177,11 +150,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) l.EnvvEnd = s.Bottom envAddrs := make([]usermem.Addr, len(env)) for i := len(env) - 1; i >= 0; i-- { - addr, err := s.Push(env[i]) - if err != nil { + if _, err := s.PushNullTerminatedByteSlice([]byte(env[i])); err != nil { return StackLayout{}, err } - envAddrs[i] = addr + envAddrs[i] = s.Bottom } l.EnvvStart = s.Bottom @@ -189,11 +161,10 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) l.ArgvEnd = s.Bottom argAddrs := make([]usermem.Addr, len(args)) for i := len(args) - 1; i >= 0; i-- { - addr, err := s.Push(args[i]) - if err != nil { + if _, err := s.PushNullTerminatedByteSlice([]byte(args[i])); err != nil { return StackLayout{}, err } - argAddrs[i] = addr + argAddrs[i] = s.Bottom } l.ArgvStart = s.Bottom @@ -222,26 +193,26 @@ func (s *Stack) Load(args []string, env []string, aux Auxv) (StackLayout, error) auxv = append(auxv, usermem.Addr(a.Key), a.Value) } auxv = append(auxv, usermem.Addr(0)) - _, err := s.Push(auxv) + _, err := s.pushAddrSliceAndTerminator(auxv) if err != nil { return StackLayout{}, err } // Push environment. - _, err = s.Push(envAddrs) + _, err = s.pushAddrSliceAndTerminator(envAddrs) if err != nil { return StackLayout{}, err } // Push args. - _, err = s.Push(argAddrs) + _, err = s.pushAddrSliceAndTerminator(argAddrs) if err != nil { return StackLayout{}, err } // Push arg count. - _, err = s.Push(usermem.Addr(len(args))) - if err != nil { + lenP := s.Arch.Native(uintptr(len(args))) + if _, err = lenP.CopyOut(s, StackBottomMagic); err != nil { return StackLayout{}, err } diff --git a/pkg/sentry/arch/stack_unsafe.go b/pkg/sentry/arch/stack_unsafe.go new file mode 100644 index 000000000..a90d297ee --- /dev/null +++ b/pkg/sentry/arch/stack_unsafe.go @@ -0,0 +1,69 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arch + +import ( + "reflect" + "runtime" + "unsafe" + + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/usermem" +) + +// pushAddrSliceAndTerminator copies a slices of addresses to the stack, and +// also pushes an extra null address element at the end of the slice. +// +// Internally, we unsafely transmute the slice type from the arch-dependent +// []usermem.Addr type, to a slice of fixed-sized ints so that we can pass it to +// go-marshal. +// +// On error, the contents of the stack and the bottom cursor are undefined. +func (s *Stack) pushAddrSliceAndTerminator(src []usermem.Addr) (int, error) { + // Note: Stack grows upwards, so push the terminator first. + srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + switch s.Arch.Width() { + case 8: + nNull, err := primitive.CopyUint64Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + var dst []uint64 + dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + dstHdr.Data = srcHdr.Data + dstHdr.Len = srcHdr.Len + dstHdr.Cap = srcHdr.Cap + n, err := primitive.CopyUint64SliceOut(s, StackBottomMagic, dst) + // Ensures src doesn't get GCed until we're done using it through dst. + runtime.KeepAlive(src) + return n + nNull, err + case 4: + nNull, err := primitive.CopyUint32Out(s, StackBottomMagic, 0) + if err != nil { + return 0, err + } + var dst []uint32 + dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + dstHdr.Data = srcHdr.Data + dstHdr.Len = srcHdr.Len + dstHdr.Cap = srcHdr.Cap + n, err := primitive.CopyUint32SliceOut(s, StackBottomMagic, dst) + // Ensure src doesn't get GCed until we're done using it through dst. + runtime.KeepAlive(src) + return n + nNull, err + default: + panic("Unsupported arch width") + } +} diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 103bfc600..22d658acf 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -84,6 +84,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo "auxv": newAuxvec(t, msrc), "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), "comm": newComm(t, msrc), + "cwd": newCwd(t, msrc), "environ": newExecArgInode(t, msrc, environExecArg), "exe": newExe(t, msrc), "fd": newFdDir(t, msrc), @@ -300,6 +301,49 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { return exec.PathnameWithDeleted(ctx), nil } +// cwd is an fs.InodeOperations symlink for the /proc/PID/cwd file. +// +// +stateify savable +type cwd struct { + ramfs.Symlink + + t *kernel.Task +} + +func newCwd(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + cwdSymlink := &cwd{ + Symlink: *ramfs.NewSymlink(t, fs.RootOwner, ""), + t: t, + } + return newProcInode(t, cwdSymlink, msrc, fs.Symlink, t) +} + +// Readlink implements fs.InodeOperations. +func (e *cwd) Readlink(ctx context.Context, inode *fs.Inode) (string, error) { + if !kernel.ContextCanTrace(ctx, e.t, false) { + return "", syserror.EACCES + } + if err := checkTaskState(e.t); err != nil { + return "", err + } + cwd := e.t.FSContext().WorkingDirectory() + if cwd == nil { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer cwd.DecRef(ctx) + + root := fs.RootFromContext(ctx) + if root == nil { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) + + name, _ := cwd.FullName(root) + return name, nil +} + // namespaceSymlink represents a symlink in the namespacefs, such as the files // in /proc/<pid>/ns. // diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index 323506d33..be0900030 100644 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go @@ -122,7 +122,7 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) // remainingTraversals is not configurable in VFS2, all callers are using the // default anyways. func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) { - vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem() + vfsObj := l.root.Mount().Filesystem().VirtualFilesystem() creds := auth.CredentialsFromContext(ctx) path := fspath.Parse(pathname) pop := &vfs.PathOperation{ diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index abc610ef3..7b1eec3da 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/fd", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", @@ -86,9 +88,9 @@ go_test( library = ":ext", deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fspath", + "//pkg/marshal/primitive", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go index 8bb104ff0..1165234f9 100644 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -18,7 +18,7 @@ import ( "io" "math" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/syserror" ) @@ -34,19 +34,19 @@ type blockMapFile struct { // directBlks are the direct blocks numbers. The physical blocks pointed by // these holds file data. Contains file blocks 0 to 11. - directBlks [numDirectBlks]uint32 + directBlks [numDirectBlks]primitive.Uint32 // indirectBlk is the physical block which contains (blkSize/4) direct block // numbers (as uint32 integers). - indirectBlk uint32 + indirectBlk primitive.Uint32 // doubleIndirectBlk is the physical block which contains (blkSize/4) indirect // block numbers (as uint32 integers). - doubleIndirectBlk uint32 + doubleIndirectBlk primitive.Uint32 // tripleIndirectBlk is the physical block which contains (blkSize/4) doubly // indirect block numbers (as uint32 integers). - tripleIndirectBlk uint32 + tripleIndirectBlk primitive.Uint32 // coverage at (i)th index indicates the amount of file data a node at // height (i) covers. Height 0 is the direct block. @@ -68,10 +68,12 @@ func newBlockMapFile(args inodeArgs) (*blockMapFile, error) { } blkMap := file.regFile.inode.diskInode.Data() - binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks) - binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk) + for i := 0; i < numDirectBlks; i++ { + file.directBlks[i].UnmarshalBytes(blkMap[i*4 : (i+1)*4]) + } + file.indirectBlk.UnmarshalBytes(blkMap[numDirectBlks*4 : (numDirectBlks+1)*4]) + file.doubleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+1)*4 : (numDirectBlks+2)*4]) + file.tripleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+2)*4 : (numDirectBlks+3)*4]) return file, nil } @@ -117,16 +119,16 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) { switch { case offset < dirBlksEnd: // Direct block. - curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:]) + curR, err = f.read(uint32(f.directBlks[offset/f.regFile.inode.blkSize]), offset%f.regFile.inode.blkSize, 0, dst[read:]) case offset < indirBlkEnd: // Indirect block. - curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:]) + curR, err = f.read(uint32(f.indirectBlk), offset-dirBlksEnd, 1, dst[read:]) case offset < doubIndirBlkEnd: // Doubly indirect block. - curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:]) + curR, err = f.read(uint32(f.doubleIndirectBlk), offset-indirBlkEnd, 2, dst[read:]) default: // Triply indirect block. - curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:]) + curR, err = f.read(uint32(f.tripleIndirectBlk), offset-doubIndirBlkEnd, 3, dst[read:]) } read += curR @@ -174,13 +176,13 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds read := 0 curChildOff := relFileOff % childCov for i := startIdx; i < endIdx; i++ { - var childPhyBlk uint32 + var childPhyBlk primitive.Uint32 err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) if err != nil { return read, err } - n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:]) + n, err := f.read(uint32(childPhyBlk), curChildOff, height-1, dst[read:]) read += n if err != nil { return read, err diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go index 6fa84e7aa..ed98b482e 100644 --- a/pkg/sentry/fsimpl/ext/block_map_test.go +++ b/pkg/sentry/fsimpl/ext/block_map_test.go @@ -20,7 +20,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) @@ -87,29 +87,33 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) { mockDisk := make([]byte, mockBMDiskSize) var fileData []byte blkNums := newBlkNumGen() - var data []byte + off := 0 + data := make([]byte, (numDirectBlks+3)*(*primitive.Uint32)(nil).SizeBytes()) // Write the direct blocks. for i := 0; i < numDirectBlks; i++ { - curBlkNum := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, curBlkNum) - fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...) + curBlkNum := primitive.Uint32(blkNums.next()) + curBlkNum.MarshalBytes(data[off:]) + off += curBlkNum.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(curBlkNum), 0, blkNums)...) } // Write to indirect block. - indirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, indirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...) - - // Write to indirect block. - doublyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...) - - // Write to indirect block. - triplyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...) + indirectBlk := primitive.Uint32(blkNums.next()) + indirectBlk.MarshalBytes(data[off:]) + off += indirectBlk.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(indirectBlk), 1, blkNums)...) + + // Write to double indirect block. + doublyIndirectBlk := primitive.Uint32(blkNums.next()) + doublyIndirectBlk.MarshalBytes(data[off:]) + off += doublyIndirectBlk.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(doublyIndirectBlk), 2, blkNums)...) + + // Write to triple indirect block. + triplyIndirectBlk := primitive.Uint32(blkNums.next()) + triplyIndirectBlk.MarshalBytes(data[off:]) + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(triplyIndirectBlk), 3, blkNums)...) args := inodeArgs{ fs: &filesystem{ @@ -142,9 +146,9 @@ func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkN var fileData []byte for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 { - curBlkNum := blkNums.next() - copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum)) - fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...) + curBlkNum := primitive.Uint32(blkNums.next()) + curBlkNum.MarshalBytes(disk[off : off+4]) + fileData = append(fileData, writeFileDataToBlock(disk, uint32(curBlkNum), height-1, blkNums)...) } return fileData } diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 452450d82..0ad79b381 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -16,7 +16,6 @@ package ext import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -100,7 +99,7 @@ func newDirectory(args inodeArgs, newDirent bool) (*directory, error) { } else { curDirent.diskDirent = &disklayout.DirentOld{} } - binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent) + curDirent.diskDirent.UnmarshalBytes(buf) if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 { // Inode number and name length fields being set to 0 is used to indicate diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 9bd9c76c0..d98a05dd8 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -22,10 +22,11 @@ go_library( "superblock_old.go", "test_utils.go", ], + marshal = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/marshal", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go index ad6f4fef8..0d56ae9da 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + // BlockGroup represents a Linux ext block group descriptor. An ext file system // is split into a series of block groups. This provides an access layer to // information needed to access and use a block group. @@ -30,6 +34,8 @@ package disklayout // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors. type BlockGroup interface { + marshal.Marshallable + // InodeTable returns the absolute block number of the block containing the // inode table. This points to an array of Inode structs. Inode tables are // statically allocated at mkfs time. The superblock records the number of diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go index 3e16c76db..a35fa22a0 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go @@ -17,6 +17,8 @@ package disklayout // BlockGroup32Bit emulates the first half of struct ext4_group_desc in // fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and // 32-bit ext4 filesystems. It implements BlockGroup interface. +// +// +marshal type BlockGroup32Bit struct { BlockBitmapLo uint32 InodeBitmapLo uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go index 9a809197a..d54d1d345 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go @@ -18,6 +18,8 @@ package disklayout // It is the block group descriptor struct for 64-bit ext4 filesystems. // It implements BlockGroup interface. It is an extension of the 32-bit // version of BlockGroup. +// +// +marshal type BlockGroup64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go index 0ef4294c0..e4ce484e4 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go @@ -21,6 +21,8 @@ import ( // TestBlockGroupSize tests that the block group descriptor structs are of the // correct size. func TestBlockGroupSize(t *testing.T) { - assertSize(t, BlockGroup32Bit{}, 32) - assertSize(t, BlockGroup64Bit{}, 64) + var bgSmall BlockGroup32Bit + assertSize(t, &bgSmall, 32) + var bgBig BlockGroup64Bit + assertSize(t, &bgBig, 64) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go index 417b6cf65..568c8cb4c 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go @@ -15,6 +15,7 @@ package disklayout import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fs" ) @@ -51,6 +52,8 @@ var ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories. type Dirent interface { + marshal.Marshallable + // Inode returns the absolute inode number of the underlying inode. // Inode number 0 signifies an unused dirent. Inode() uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go index 29ae4a5c2..51f9c2946 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go @@ -29,12 +29,14 @@ import ( // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain // garbage. +// +// +marshal type DirentNew struct { InodeNumber uint32 RecordLength uint16 NameLength uint8 FileTypeRaw uint8 - FileNameRaw [MaxFileName]byte + FileNameRaw [MaxFileName]byte `marshal:"unaligned"` } // Compiles only if DirentNew implements Dirent. diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go index 6fff12a6e..d4b19e086 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go @@ -22,11 +22,13 @@ import "gvisor.dev/gvisor/pkg/sentry/fs" // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain // garbage. +// +// +marshal type DirentOld struct { InodeNumber uint32 RecordLength uint16 NameLength uint16 - FileNameRaw [MaxFileName]byte + FileNameRaw [MaxFileName]byte `marshal:"unaligned"` } // Compiles only if DirentOld implements Dirent. diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go index 934919f8a..3486864dc 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go @@ -21,6 +21,8 @@ import ( // TestDirentSize tests that the dirent structs are of the correct // size. func TestDirentSize(t *testing.T) { - assertSize(t, DirentOld{}, uintptr(DirentSize)) - assertSize(t, DirentNew{}, uintptr(DirentSize)) + var dOld DirentOld + assertSize(t, &dOld, DirentSize) + var dNew DirentNew + assertSize(t, &dNew, DirentSize) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go index bdf4e2132..0834e9ba8 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go +++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go @@ -36,8 +36,6 @@ // escape analysis on an unknown implementation at compile time. // // Notes: -// - All fields in these structs are exported because binary.Read would -// panic otherwise. // - All structures on disk are in little-endian order. Only jbd2 (journal) // structures are in big-endian order. // - All OS dependent fields in these structures will be interpretted using diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go index 4110649ab..b13999bfc 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + // Extents were introduced in ext4 and provide huge performance gains in terms // data locality and reduced metadata block usage. Extents are organized in // extent trees. The root node is contained in inode.BlocksRaw. @@ -64,6 +68,8 @@ type ExtentNode struct { // ExtentEntry represents an extent tree node entry. The entry can either be // an ExtentIdx or Extent itself. This exists to simplify navigation logic. type ExtentEntry interface { + marshal.Marshallable + // FileBlock returns the first file block number covered by this entry. FileBlock() uint32 @@ -75,6 +81,8 @@ type ExtentEntry interface { // tree node begins with this and is followed by `NumEntries` number of: // - Extent if `Depth` == 0 // - ExtentIdx otherwise +// +// +marshal type ExtentHeader struct { // Magic in the extent magic number, must be 0xf30a. Magic uint16 @@ -96,6 +104,8 @@ type ExtentHeader struct { // internal nodes. Sorted in ascending order based on FirstFileBlock since // Linux does a binary search on this. This points to a block containing the // child node. +// +// +marshal type ExtentIdx struct { FirstFileBlock uint32 ChildBlockLo uint32 @@ -121,6 +131,8 @@ func (ei *ExtentIdx) PhysicalBlock() uint64 { // nodes. Sorted in ascending order based on FirstFileBlock since Linux does a // binary search on this. This points to an array of data blocks containing the // file data. It covers `Length` data blocks starting from `StartBlock`. +// +// +marshal type Extent struct { FirstFileBlock uint32 Length uint16 diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go index 8762b90db..c96002e19 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go @@ -21,7 +21,10 @@ import ( // TestExtentSize tests that the extent structs are of the correct // size. func TestExtentSize(t *testing.T) { - assertSize(t, ExtentHeader{}, ExtentHeaderSize) - assertSize(t, ExtentIdx{}, ExtentEntrySize) - assertSize(t, Extent{}, ExtentEntrySize) + var h ExtentHeader + assertSize(t, &h, ExtentHeaderSize) + var i ExtentIdx + assertSize(t, &i, ExtentEntrySize) + var e Extent + assertSize(t, &e, ExtentEntrySize) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go index 88ae913f5..ef25040a9 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go @@ -16,6 +16,7 @@ package disklayout import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) @@ -38,6 +39,8 @@ const ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes. type Inode interface { + marshal.Marshallable + // Mode returns the linux file mode which is majorly used to extract // information like: // - File permissions (read/write/execute by user/group/others). diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go index 8f9f574ce..a4503f5cf 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go @@ -27,6 +27,8 @@ import "gvisor.dev/gvisor/pkg/sentry/kernel/time" // are used to provide nanoscond precision. Hence, these timestamps will now // overflow in May 2446. // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps. +// +// +marshal type InodeNew struct { InodeOld diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go index db25b11b6..e6b28babf 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go @@ -30,6 +30,8 @@ const ( // // All fields representing time are in seconds since the epoch. Which means that // they will overflow in January 2038. +// +// +marshal type InodeOld struct { ModeRaw uint16 UIDLo uint16 diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go index dd03ee50e..90744e956 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go @@ -24,10 +24,12 @@ import ( // TestInodeSize tests that the inode structs are of the correct size. func TestInodeSize(t *testing.T) { - assertSize(t, InodeOld{}, OldInodeSize) + var iOld InodeOld + assertSize(t, &iOld, OldInodeSize) // This was updated from 156 bytes to 160 bytes in Oct 2015. - assertSize(t, InodeNew{}, 160) + var iNew InodeNew + assertSize(t, &iNew, 160) } // TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go index 8bb327006..70948ebe9 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + const ( // SbOffset is the absolute offset at which the superblock is placed. SbOffset = 1024 @@ -38,6 +42,8 @@ const ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block. type SuperBlock interface { + marshal.Marshallable + // InodesCount returns the total number of inodes in this filesystem. InodesCount() uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go index 53e515fd3..4dc6080fb 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go @@ -17,6 +17,8 @@ package disklayout // SuperBlock32Bit implements SuperBlock and represents the 32-bit version of // the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if // RevLevel = DynamicRev and 64-bit feature is disabled. +// +// +marshal type SuperBlock32Bit struct { // We embed the old superblock struct here because the 32-bit version is just // an extension of the old version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go index 7c1053fb4..2c9039327 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go @@ -19,6 +19,8 @@ package disklayout // 1024 bytes (smallest possible block size) and hence the superblock always // fits in no more than one data block. Should only be used when the 64-bit // feature is set. +// +// +marshal type SuperBlock64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go index 9221e0251..e4709f23c 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go @@ -16,6 +16,8 @@ package disklayout // SuperBlockOld implements SuperBlock and represents the old version of the // superblock struct. Should be used only if RevLevel = OldRev. +// +// +marshal type SuperBlockOld struct { InodesCountRaw uint32 BlocksCountLo uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go index 463b5ba21..b734b6987 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go @@ -21,7 +21,10 @@ import ( // TestSuperBlockSize tests that the superblock structs are of the correct // size. func TestSuperBlockSize(t *testing.T) { - assertSize(t, SuperBlockOld{}, 84) - assertSize(t, SuperBlock32Bit{}, 336) - assertSize(t, SuperBlock64Bit{}, 1024) + var sbOld SuperBlockOld + assertSize(t, &sbOld, 84) + var sb32 SuperBlock32Bit + assertSize(t, &sb32, 336) + var sb64 SuperBlock64Bit + assertSize(t, &sb64, 1024) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go index 9c63f04c0..a4bc08411 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go +++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go @@ -18,13 +18,13 @@ import ( "reflect" "testing" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" ) -func assertSize(t *testing.T, v interface{}, want uintptr) { +func assertSize(t *testing.T, v marshal.Marshallable, want int) { t.Helper() - if got := binary.Size(v); got != want { + if got := v.SizeBytes(); got != want { t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got) } } diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index 04917d762..778460107 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -18,7 +18,6 @@ import ( "io" "sort" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) @@ -60,7 +59,7 @@ func newExtentFile(args inodeArgs) (*extentFile, error) { func (f *extentFile) buildExtTree() error { rootNodeData := f.regFile.inode.diskInode.Data() - binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header) + f.root.Header.UnmarshalBytes(rootNodeData[:disklayout.ExtentHeaderSize]) // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries. if f.root.Header.NumEntries > 4 { @@ -79,7 +78,7 @@ func (f *extentFile) buildExtTree() error { // Internal node. curEntry = &disklayout.ExtentIdx{} } - binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry) + curEntry.UnmarshalBytes(rootNodeData[off : off+disklayout.ExtentEntrySize]) f.root.Entries[i].Entry = curEntry } diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go index cd10d46ee..985f76ac0 100644 --- a/pkg/sentry/fsimpl/ext/extent_test.go +++ b/pkg/sentry/fsimpl/ext/extent_test.go @@ -21,7 +21,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) @@ -202,13 +201,14 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, [] // writeTree writes the tree represented by `root` to the inode and disk. It // also writes random file data on disk. func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte { - rootData := binary.Marshal(nil, binary.LittleEndian, root.Header) + rootData := in.diskInode.Data() + root.Header.MarshalBytes(rootData) + off := root.Header.SizeBytes() for _, ep := range root.Entries { - rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry) + ep.Entry.MarshalBytes(rootData[off:]) + off += ep.Entry.SizeBytes() } - copy(in.diskInode.Data(), rootData) - var fileData []byte for _, ep := range root.Entries { if root.Header.Height == 0 { @@ -223,13 +223,14 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBl // writeTreeToDisk is the recursive step for writeTree which writes the tree // on the disk only. Also writes random file data on disk. func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte { - nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header) + nodeData := disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:] + curNode.Node.Header.MarshalBytes(nodeData) + off := curNode.Node.Header.SizeBytes() for _, ep := range curNode.Node.Entries { - nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry) + ep.Entry.MarshalBytes(nodeData[off:]) + off += ep.Entry.SizeBytes() } - copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData) - var fileData []byte for _, ep := range curNode.Node.Entries { if curNode.Node.Header.Height == 0 { diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go index d8b728f8c..58ef7b9b8 100644 --- a/pkg/sentry/fsimpl/ext/utils.go +++ b/pkg/sentry/fsimpl/ext/utils.go @@ -17,21 +17,21 @@ package ext import ( "io" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) // readFromDisk performs a binary read from disk into the given struct from // the absolute offset provided. -func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error { - n := binary.Size(v) +func readFromDisk(dev io.ReaderAt, abOff int64, v marshal.Marshallable) error { + n := v.SizeBytes() buf := make([]byte, n) if read, _ := dev.ReadAt(buf, abOff); read < int(n) { return syserror.EIO } - binary.Unmarshal(buf, binary.LittleEndian, v) + v.UnmarshalBytes(buf) return nil } diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 131145b85..8a447e29f 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -348,10 +348,10 @@ func (e *SCMConnectedEndpoint) Init() error { func (e *SCMConnectedEndpoint) Release(ctx context.Context) { e.DecRef(func() { e.mu.Lock() + fdnotifier.RemoveFD(int32(e.fd)) if err := syscall.Close(e.fd); err != nil { log.Warningf("Failed to close host fd %d: %v", err) } - fdnotifier.RemoveFD(int32(e.fd)) e.destroyLocked() e.mu.Unlock() }) diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 1f99183eb..a7cd6f57e 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -53,6 +53,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace "auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}), "cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}), "comm": fs.newComm(task, fs.NextIno(), 0444), + "cwd": fs.newCwdSymlink(task, fs.NextIno()), "environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}), "exe": fs.newExeSymlink(task, fs.NextIno()), "fd": fs.newFDDirInode(task), diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index b81c8279e..3fbf081a6 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -669,18 +669,22 @@ func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentr // Readlink implements kernfs.Inode.Readlink. func (s *exeSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { - if !kernel.ContextCanTrace(ctx, s.task, false) { - return "", syserror.EACCES - } - - // Pull out the executable for /proc/[pid]/exe. - exec, err := s.executable() + exec, _, err := s.Getlink(ctx, nil) if err != nil { return "", err } defer exec.DecRef(ctx) - return exec.PathnameWithDeleted(ctx), nil + root := vfs.RootFromContext(ctx) + if !root.Ok() { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) + + vfsObj := exec.Mount().Filesystem().VirtualFilesystem() + name, _ := vfsObj.PathnameWithDeleted(ctx, root, exec) + return name, nil } // Getlink implements kernfs.Inode.Getlink. @@ -688,23 +692,12 @@ func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDent if !kernel.ContextCanTrace(ctx, s.task, false) { return vfs.VirtualDentry{}, "", syserror.EACCES } - - exec, err := s.executable() - if err != nil { - return vfs.VirtualDentry{}, "", err - } - defer exec.DecRef(ctx) - - vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry() - vd.IncRef() - return vd, "", nil -} - -func (s *exeSymlink) executable() (file fsbridge.File, err error) { if err := checkTaskState(s.task); err != nil { - return nil, err + return vfs.VirtualDentry{}, "", err } + var err error + var exec fsbridge.File s.task.WithMuLocked(func(t *kernel.Task) { mm := t.MemoryManager() if mm == nil { @@ -715,12 +708,78 @@ func (s *exeSymlink) executable() (file fsbridge.File, err error) { // The MemoryManager may be destroyed, in which case // MemoryManager.destroy will simply set the executable to nil // (with locks held). - file = mm.Executable() - if file == nil { + exec = mm.Executable() + if exec == nil { err = syserror.ESRCH } }) - return + if err != nil { + return vfs.VirtualDentry{}, "", err + } + defer exec.DecRef(ctx) + + vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry() + vd.IncRef() + return vd, "", nil +} + +// cwdSymlink is an symlink for the /proc/[pid]/cwd file. +// +// +stateify savable +type cwdSymlink struct { + implStatFS + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeSymlink + + task *kernel.Task +} + +var _ kernfs.Inode = (*cwdSymlink)(nil) + +func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry { + inode := &cwdSymlink{task: task} + inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + + d := &kernfs.Dentry{} + d.Init(inode) + return d +} + +// Readlink implements kernfs.Inode.Readlink. +func (s *cwdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { + cwd, _, err := s.Getlink(ctx, nil) + if err != nil { + return "", err + } + defer cwd.DecRef(ctx) + + root := vfs.RootFromContext(ctx) + if !root.Ok() { + // It could have raced with process deletion. + return "", syserror.ESRCH + } + defer root.DecRef(ctx) + + vfsObj := cwd.Mount().Filesystem().VirtualFilesystem() + name, _ := vfsObj.PathnameWithDeleted(ctx, root, cwd) + return name, nil +} + +// Getlink implements kernfs.Inode.Getlink. +func (s *cwdSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { + if !kernel.ContextCanTrace(ctx, s.task, false) { + return vfs.VirtualDentry{}, "", syserror.EACCES + } + if err := checkTaskState(s.task); err != nil { + return vfs.VirtualDentry{}, "", err + } + cwd := s.task.FSContext().WorkingDirectoryVFS2() + if !cwd.Ok() { + // It could have raced with process deletion. + return vfs.VirtualDentry{}, "", syserror.ESRCH + } + return cwd, "", nil } // mountInfoData is used to implement /proc/[pid]/mountinfo. diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index f693f9060..6975af5a7 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -67,6 +67,7 @@ var ( taskStaticFiles = map[string]testutil.DirentType{ "auxv": linux.DT_REG, "cgroup": linux.DT_REG, + "cwd": linux.DT_LNK, "cmdline": linux.DT_REG, "comm": linux.DT_REG, "environ": linux.DT_REG, diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD index 067c1657f..adb610213 100644 --- a/pkg/sentry/fsimpl/signalfd/BUILD +++ b/pkg/sentry/fsimpl/signalfd/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/kernel", "//pkg/sentry/vfs", diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index bf11b425a..10f1452ef 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -16,7 +16,6 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -95,8 +94,7 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen } // Copy out the signal info using the specified format. - var buf [128]byte - binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + infoNative := linux.SignalfdSiginfo{ Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, @@ -105,9 +103,13 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), - }) - n, err := dst.CopyOut(ctx, buf[:]) - return int64(n), err + } + n, err := infoNative.WriteTo(dst.Writer(ctx)) + if err == usermem.ErrEndOfIOSequence { + // Partial copy-out ok. + err = nil + } + return n, err } // Readiness implements waiter.Waitable.Readiness. diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index b75d70ae6..1a6749e53 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -104,7 +104,7 @@ func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro func (fd *kcovFD) Release(ctx context.Context) { // kcov instances have reference counts in Linux, but this seems sufficient // for our purposes. - fd.kcov.Reset() + fd.kcov.Clear() } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index bc8e38431..0ca750281 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") licenses(["notice"]) @@ -26,3 +26,22 @@ go_library( "//pkg/usermem", ], ) + +go_test( + name = "verity_test", + srcs = [ + "verity_test.go", + ], + library = ":verity", + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/fspath", + "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/contexttest", + "//pkg/sentry/vfs", + "//pkg/usermem", + ], +) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 26b117ca4..a560b0797 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -693,22 +693,24 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // be called if a verity FD is created successfully. defer merkleWriter.DecRef(ctx) - parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ - Root: d.parent.lowerMerkleVD, - Start: d.parent.lowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_WRONLY | linux.O_APPEND, - }) - if err != nil { - if err == syserror.ENOENT { - parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + if d.parent != nil { + parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.parent.lowerMerkleVD, + Start: d.parent.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + } + return nil, err } - return nil, err + // parentMerkleWriter is cleaned up if any error occurs. IncRef + // will be called if a verity FD is created successfully. + defer parentMerkleWriter.DecRef(ctx) } - // parentMerkleWriter is cleaned up if any error occurs. IncRef - // will be called if a verity FD is created successfully. - defer parentMerkleWriter.DecRef(ctx) } fd := &fileDescription{ diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 3129f290d..fc5eabbca 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -558,7 +558,7 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, // enableVerity enables verity features on fd by generating a Merkle tree file // and stores its root hash in its parent directory's Merkle tree. -func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) { if !fd.d.fs.allowRuntimeEnable { return 0, syserror.EPERM } @@ -568,7 +568,11 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg verityMu.Lock() defer verityMu.Unlock() - if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || fd.parentMerkleWriter == nil { + // In allowRuntimeEnable mode, the underlying fd and read/write fd for + // the Merkle tree file should have all been initialized. For any file + // or directory other than the root, the parent Merkle tree file should + // have also been initialized. + if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") } @@ -577,26 +581,28 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg return 0, err } - stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{}) - if err != nil { - return 0, err - } + if fd.parentMerkleWriter != nil { + stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return 0, err + } - // Write the root hash of fd to the parent directory's Merkle tree - // file, as it should be part of the parent Merkle tree data. - // parentMerkleWriter is open with O_APPEND, so it should write - // directly to the end of the file. - if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil { - return 0, err - } + // Write the root hash of fd to the parent directory's Merkle + // tree file, as it should be part of the parent Merkle tree + // data. parentMerkleWriter is open with O_APPEND, so it + // should write directly to the end of the file. + if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil { + return 0, err + } - // Record the offset of the root hash of fd in parent directory's - // Merkle tree file. - if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ - Name: merkleOffsetInParentXattr, - Value: strconv.Itoa(int(stat.Size)), - }); err != nil { - return 0, err + // Record the offset of the root hash of fd in parent directory's + // Merkle tree file. + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: merkleOffsetInParentXattr, + Value: strconv.Itoa(int(stat.Size)), + }); err != nil { + return 0, err + } } // Record the size of the data being hashed for fd. @@ -610,7 +616,45 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg return 0, nil } -func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { +// measureVerity returns the root hash of fd, saved in args[2]. +func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + var metadata linux.DigestMetadata + + // If allowRuntimeEnable is true, an empty fd.d.rootHash indicates that + // verity is not enabled for the file. If allowRuntimeEnable is false, + // this is an integrity violation because all files should have verity + // enabled, in which case fd.d.rootHash should be set. + if len(fd.d.rootHash) == 0 { + if fd.d.fs.allowRuntimeEnable { + return 0, syserror.ENODATA + } + return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no root hash found") + } + + // The first part of VerityDigest is the metadata. + if _, err := metadata.CopyIn(t, verityDigest); err != nil { + return 0, err + } + if metadata.DigestSize < uint16(len(fd.d.rootHash)) { + return 0, syserror.EOVERFLOW + } + + // Populate the output digest size, since DigestSize is both input and + // output. + metadata.DigestSize = uint16(len(fd.d.rootHash)) + + // First copy the metadata. + if _, err := metadata.CopyOut(t, verityDigest); err != nil { + return 0, err + } + + // Now copy the root hash bytes to the memory after metadata. + _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.rootHash) + return 0, err +} + +func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) { f := int32(0) // All enabled files should store a root hash. This flag is not settable @@ -620,8 +664,7 @@ func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args ar } t := kernel.TaskFromContext(ctx) - addr := args[2].Pointer() - _, err := primitive.CopyInt32Out(t, addr, f) + _, err := primitive.CopyInt32Out(t, flags, f) return 0, err } @@ -629,11 +672,15 @@ func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args ar func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := args[1].Uint(); cmd { case linux.FS_IOC_ENABLE_VERITY: - return fd.enableVerity(ctx, uio, args) + return fd.enableVerity(ctx, uio) + case linux.FS_IOC_MEASURE_VERITY: + return fd.measureVerity(ctx, uio, args[2].Pointer()) case linux.FS_IOC_GETFLAGS: - return fd.getFlags(ctx, uio, args) + return fd.verityFlags(ctx, uio, args[2].Pointer()) default: - return fd.lowerFD.Ioctl(ctx, uio, args) + // TODO(b/169682228): Investigate which ioctl commands should + // be allowed. + return 0, syserror.ENOSYS } } diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go new file mode 100644 index 000000000..8bcc14131 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -0,0 +1,429 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "fmt" + "io" + "math/rand" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +// rootMerkleFilename is the name of the root Merkle tree file. +const rootMerkleFilename = "root.verity" + +// maxDataSize is the maximum data size written to the file for test. +const maxDataSize = 100000 + +// newVerityRoot creates a new verity mount, and returns the root. The +// underlying file system is tmpfs. If the error is not nil, then cleanup +// should be called when the root is no longer needed. +func newVerityRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) { + rand.Seed(time.Now().UnixNano()) + vfsObj := &vfs.VirtualFilesystem{} + if err := vfsObj.Init(ctx); err != nil { + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) + } + + vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + + vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + + mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: InternalFilesystemOptions{ + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + AllowRuntimeEnable: true, + NoCrashOnVerificationFailure: true, + }, + }, + }) + if err != nil { + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err) + } + root := mntns.Root() + return vfsObj, root, func() { + root.DecRef(ctx) + mntns.DecRef(ctx) + }, nil +} + +// newFileFD creates a new file in the verity mount, and returns the FD. The FD +// points to a file that has random data generated. +func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { + creds := auth.CredentialsFromContext(ctx) + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + + // Create the file in the underlying file system. + lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(filePath), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, + Mode: linux.ModeRegular | mode, + }) + if err != nil { + return nil, 0, err + } + + // Generate random data to be written to the file. + dataSize := rand.Intn(maxDataSize) + 1 + data := make([]byte, dataSize) + rand.Read(data) + + // Write directly to the underlying FD, since verity FD is read-only. + n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) + if err != nil { + return nil, 0, err + } + + if n != int64(len(data)) { + return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data)) + } + + lowerFD.DecRef(ctx) + + // Now open the verity file descriptor. + fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filePath), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular | mode, + }) + return fd, dataSize, err +} + +// corruptRandomBit randomly flips a bit in the file represented by fd. +func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { + // Flip a random bit in the underlying file. + randomPos := int64(rand.Intn(size)) + byteToModify := make([]byte, 1) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { + return fmt.Errorf("lowerFD.PRead: %v", err) + } + byteToModify[0] ^= 1 + if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil { + return fmt.Errorf("lowerFD.PWrite: %v", err) + } + return nil +} + +// TestOpen ensures that when a file is created, the corresponding Merkle tree +// file and the root Merkle tree file exist. +func TestOpen(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Ensure that the corresponding Merkle tree file is created. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + } + + // Ensure the root merkle tree file is created. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + rootMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + } +} + +// TestUntouchedFileSucceeds ensures that read from an untouched verity file +// succeeds after enabling verity for it. +func TestReadUntouchedFileSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.PRead: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } +} + +// TestReopenUntouchedFileSucceeds ensures that reopen an untouched verity file +// succeeds after enabling verity for it. +func TestReopenUntouchedFileSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Ensure reopening the verity enabled file succeeds. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != nil { + t.Errorf("reopen enabled file failed: %v", err) + } +} + +// TestModifiedFileFails ensures that read from a modified verity file fails. +func TestModifiedFileFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded with modified file") + } +} + +// TestModifiedMerkleFails ensures that read from a verity file fails if the +// corresponding Merkle tree file is modified. +func TestModifiedMerkleFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the Merkle tree file. + stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + merkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from a file with modified Merkle tree fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + fmt.Println(buf) + t.Fatalf("fd.PRead succeeded with modified Merkle file") + } +} + +// TestModifiedParentMerkleFails ensures that open a verity enabled file in a +// verity enabled directory fails if the hashes related to the target file in +// the parent Merkle tree file is modified. +func TestModifiedParentMerkleFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, cleanup, err := newVerityRoot(ctx) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + defer cleanup() + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Enable verity on the parent directory. + parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD + + parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: parentLowerMerkleVD, + Start: parentLowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the parent Merkle tree file. + // This parent directory contains only one child, so any random + // modification in the parent Merkle tree should cause verification + // failure when opening the child file. + stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + parentMerkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + parentLowerMerkleFD.DecRef(ctx) + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err == nil { + t.Errorf("OpenAt file with modified parent Merkle succeeded") + } +} diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index 61c78569d..300b7ccce 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -7,11 +7,14 @@ go_library( srcs = [ "cgroup.go", "hostmm.go", + "membarrier.go", ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/abi/linux", "//pkg/fd", "//pkg/log", "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/hostmm/membarrier.go b/pkg/sentry/hostmm/membarrier.go new file mode 100644 index 000000000..4468d75f1 --- /dev/null +++ b/pkg/sentry/hostmm/membarrier.go @@ -0,0 +1,90 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostmm + +import ( + "syscall" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/log" +) + +var ( + haveMembarrierGlobal = false + haveMembarrierPrivateExpedited = false +) + +func init() { + supported, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_QUERY, 0 /* flags */, 0 /* unused */) + if e != 0 { + if e != syscall.ENOSYS { + log.Warningf("membarrier(MEMBARRIER_CMD_QUERY) failed: %s", e.Error()) + } + return + } + // We don't use MEMBARRIER_CMD_GLOBAL_EXPEDITED because this sends IPIs to + // all CPUs running tasks that have previously invoked + // MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED, which presents a DOS risk. + // (MEMBARRIER_CMD_GLOBAL is synchronize_rcu(), i.e. it waits for an RCU + // grace period to elapse without bothering other CPUs. + // MEMBARRIER_CMD_PRIVATE_EXPEDITED sends IPIs only to CPUs running tasks + // sharing the caller's MM.) + if supported&linux.MEMBARRIER_CMD_GLOBAL != 0 { + haveMembarrierGlobal = true + } + if req := uintptr(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED | linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED); supported&req == req { + if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 { + log.Warningf("membarrier(MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED) failed: %s", e.Error()) + } else { + haveMembarrierPrivateExpedited = true + } + } +} + +// HaveGlobalMemoryBarrier returns true if GlobalMemoryBarrier is supported. +func HaveGlobalMemoryBarrier() bool { + return haveMembarrierGlobal +} + +// GlobalMemoryBarrier blocks until "all running threads [in the host OS] have +// passed through a state where all memory accesses to user-space addresses +// match program order between entry to and return from [GlobalMemoryBarrier]", +// as for membarrier(2). +// +// Preconditions: HaveGlobalMemoryBarrier() == true. +func GlobalMemoryBarrier() error { + if _, _, e := syscall.Syscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_GLOBAL, 0 /* flags */, 0 /* unused */); e != 0 { + return e + } + return nil +} + +// HaveProcessMemoryBarrier returns true if ProcessMemoryBarrier is supported. +func HaveProcessMemoryBarrier() bool { + return haveMembarrierPrivateExpedited +} + +// ProcessMemoryBarrier is equivalent to GlobalMemoryBarrier, but only +// synchronizes with threads sharing a virtual address space (from the host OS' +// perspective) with the calling thread. +// +// Preconditions: HaveProcessMemoryBarrier() == true. +func ProcessMemoryBarrier() error { + if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 { + return e + } + return nil +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a43c549f1..5de70aecb 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -69,8 +69,8 @@ go_template_instance( prefix = "socket", template = "//pkg/ilist:generic_list", types = { - "Element": "*SocketEntry", - "Linker": "*SocketEntry", + "Element": "*SocketRecordVFS1", + "Linker": "*SocketRecordVFS1", }, ) @@ -204,7 +204,6 @@ go_library( "//pkg/abi", "//pkg/abi/linux", "//pkg/amutex", - "//pkg/binary", "//pkg/bits", "//pkg/bpf", "//pkg/context", diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go index d3e76ca7b..060c056df 100644 --- a/pkg/sentry/kernel/kcov.go +++ b/pkg/sentry/kernel/kcov.go @@ -89,7 +89,7 @@ func (kcov *Kcov) TaskWork(t *Task) { kcov.mu.Lock() defer kcov.mu.Unlock() - if kcov.mode != linux.KCOV_TRACE_PC { + if kcov.mode != linux.KCOV_MODE_TRACE_PC { return } @@ -146,7 +146,7 @@ func (kcov *Kcov) InitTrace(size uint64) error { } // EnableTrace performs the KCOV_ENABLE_TRACE ioctl. -func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error { +func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error { t := TaskFromContext(ctx) if t == nil { panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine") @@ -160,9 +160,9 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error { return syserror.EINVAL } - switch traceMode { + switch traceKind { case linux.KCOV_TRACE_PC: - kcov.mode = traceMode + kcov.mode = linux.KCOV_MODE_TRACE_PC case linux.KCOV_TRACE_CMP: // We do not support KCOV_MODE_TRACE_CMP. return syserror.ENOTSUP @@ -175,6 +175,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error { } kcov.owningTask = t + t.SetKcov(kcov) t.RegisterWork(kcov) // Clear existing coverage data; the task expects to read only coverage data @@ -196,26 +197,35 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error { if t != kcov.owningTask { return syserror.EINVAL } - kcov.owningTask = nil kcov.mode = linux.KCOV_MODE_INIT - kcov.resetLocked() + kcov.owningTask = nil + kcov.mappable = nil return nil } -// Reset is called when the owning task exits. -func (kcov *Kcov) Reset() { +// Clear resets the mode and clears the owning task and memory mapping for kcov. +// It is called when the fd corresponding to kcov is closed. Note that the mode +// needs to be set so that the next call to kcov.TaskWork() will exit early. +func (kcov *Kcov) Clear() { kcov.mu.Lock() - kcov.resetLocked() + kcov.clearLocked() kcov.mu.Unlock() } -// The kcov instance is reset when the owning task exits or when tracing is -// disabled. -func (kcov *Kcov) resetLocked() { +func (kcov *Kcov) clearLocked() { + kcov.mode = linux.KCOV_MODE_INIT kcov.owningTask = nil - if kcov.mappable != nil { - kcov.mappable = nil - } + kcov.mappable = nil +} + +// OnTaskExit is called when the owning task exits. It is similar to +// kcov.Clear(), except the memory mapping is not cleared, so that the same +// mapping can be used in the future if kcov is enabled again by another task. +func (kcov *Kcov) OnTaskExit() { + kcov.mu.Lock() + kcov.mode = linux.KCOV_MODE_INIT + kcov.owningTask = nil + kcov.mu.Unlock() } // ConfigureMMap is called by the vfs.FileDescription for this kcov instance to diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 08bb5bd12..d6c21adb7 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -220,13 +220,18 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // sockets is the list of all network sockets the system. Protected by - // extMu. + // sockets is the list of all network sockets in the system. + // Protected by extMu. + // TODO(gvisor.dev/issue/1624): Only used by VFS1. sockets socketList - // nextSocketEntry is the next entry number to use in sockets. Protected + // socketsVFS2 records all network sockets in the system. Protected by + // extMu. + socketsVFS2 map[*vfs.FileDescription]*SocketRecord + + // nextSocketRecord is the next entry number to use in sockets. Protected // by extMu. - nextSocketEntry uint64 + nextSocketRecord uint64 // deviceRegistry is used to save/restore device.SimpleDevices. deviceRegistry struct{} `state:".(*device.Registry)"` @@ -414,6 +419,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { return fmt.Errorf("failed to create sockfs mount: %v", err) } k.socketMount = socketMount + + k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord) } return nil @@ -1512,20 +1519,27 @@ func (k *Kernel) SupervisorContext() context.Context { } } -// SocketEntry represents a socket recorded in Kernel.sockets. It implements +// SocketRecord represents a socket recorded in Kernel.socketsVFS2. +// +// +stateify savable +type SocketRecord struct { + k *Kernel + Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1. + SockVFS2 *vfs.FileDescription // Only used by VFS2. + ID uint64 // Socket table entry number. +} + +// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // // +stateify savable -type SocketEntry struct { +type SocketRecordVFS1 struct { socketEntry - k *Kernel - Sock *refs.WeakRef - SockVFS2 *vfs.FileDescription - ID uint64 // Socket table entry number. + SocketRecord } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *SocketEntry) WeakRefGone(context.Context) { +func (s *SocketRecordVFS1) WeakRefGone(context.Context) { s.k.extMu.Lock() s.k.sockets.Remove(s) s.k.extMu.Unlock() @@ -1536,9 +1550,14 @@ func (s *SocketEntry) WeakRefGone(context.Context) { // Precondition: Caller must hold a reference to sock. func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() - id := k.nextSocketEntry - k.nextSocketEntry++ - s := &SocketEntry{k: k, ID: id} + id := k.nextSocketRecord + k.nextSocketRecord++ + s := &SocketRecordVFS1{ + SocketRecord: SocketRecord{ + k: k, + ID: id, + }, + } s.Sock = refs.NewWeakRef(sock, s) k.sockets.PushBack(s) k.extMu.Unlock() @@ -1550,29 +1569,45 @@ func (k *Kernel) RecordSocket(sock *fs.File) { // Precondition: Caller must hold a reference to sock. // // Note that the socket table will not hold a reference on the -// vfs.FileDescription, because we do not support weak refs on VFS2 files. +// vfs.FileDescription. func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) { k.extMu.Lock() - id := k.nextSocketEntry - k.nextSocketEntry++ - s := &SocketEntry{ + if _, ok := k.socketsVFS2[sock]; ok { + panic(fmt.Sprintf("Socket %p added twice", sock)) + } + id := k.nextSocketRecord + k.nextSocketRecord++ + s := &SocketRecord{ k: k, ID: id, SockVFS2: sock, } - k.sockets.PushBack(s) + k.socketsVFS2[sock] = s + k.extMu.Unlock() +} + +// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table. +func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) { + k.extMu.Lock() + delete(k.socketsVFS2, sock) k.extMu.Unlock() } // ListSockets returns a snapshot of all sockets. // -// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef() +// Callers of ListSockets() in VFS2 should use SocketRecord.SockVFS2.TryIncRef() // to get a reference on a socket in the table. -func (k *Kernel) ListSockets() []*SocketEntry { +func (k *Kernel) ListSockets() []*SocketRecord { k.extMu.Lock() - var socks []*SocketEntry - for s := k.sockets.Front(); s != nil; s = s.Next() { - socks = append(socks, s) + var socks []*SocketRecord + if VFS2Enabled { + for _, s := range k.socketsVFS2 { + socks = append(socks, s) + } + } else { + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, &s.SocketRecord) + } } k.extMu.Unlock() return socks diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 449643118..99134e634 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/amutex", "//pkg/buffer", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 6d58b682f..f665920cb 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -145,9 +146,14 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume v = math.MaxInt32 // Silently truncate. } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + iocc := primitive.IOCopyContext{ + IO: io, + Ctx: ctx, + Opts: usermem.IOOpts{ + AddressSpaceActive: true, + }, + } + _, err := primitive.CopyInt32Out(&iocc, args[2].Pointer(), int32(v)) return 0, err default: return 0, syscall.ENOTTY diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index c38c5a40c..387edfa91 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -18,7 +18,6 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" @@ -27,25 +26,18 @@ import ( const maxSyscallFilterInstructions = 1 << 15 -// seccompData is equivalent to struct seccomp_data, which contains the data -// passed to seccomp-bpf filters. -type seccompData struct { - // nr is the system call number. - nr int32 - - // arch is an AUDIT_ARCH_* value indicating the system call convention. - arch uint32 - - // instructionPointer is the value of the instruction pointer at the time - // of the system call. - instructionPointer uint64 - - // args contains the first 6 system call arguments. - args [6]uint64 -} - -func (d *seccompData) asBPFInput() bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, usermem.ByteOrder, d), usermem.ByteOrder} +// dataAsBPFInput returns a serialized BPF program, only valid on the current task +// goroutine. +// +// Note: this is called for every syscall, which is a very hot path. +func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input { + buf := t.CopyScratchBuffer(d.SizeBytes()) + d.MarshalUnsafe(buf) + return bpf.InputBytes{ + Data: buf, + // Go-marshal always uses the native byte order. + Order: usermem.ByteOrder, + } } func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo { @@ -112,20 +104,20 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u } func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 { - data := seccompData{ - nr: sysno, - arch: t.tc.st.AuditNumber, - instructionPointer: uint64(ip), + data := linux.SeccompData{ + Nr: sysno, + Arch: t.tc.st.AuditNumber, + InstructionPointer: uint64(ip), } // data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so // we can't do any slicing tricks or even use copy/append here. for i, arg := range args { - if i >= len(data.args) { + if i >= len(data.Args) { break } - data.args[i] = arg.Uint64() + data.Args[i] = arg.Uint64() } - input := data.asBPFInput() + input := dataAsBPFInput(t, &data) ret := uint32(linux.SECCOMP_RET_ALLOW) f := t.syscallFilters.Load() diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 3eb78e91b..76d472292 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index b07e1c1bd..78f718cfe 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -17,7 +17,6 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" @@ -103,8 +102,7 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } // Copy out the signal info using the specified format. - var buf [128]byte - binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + infoNative := linux.SignalfdSiginfo{ Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, @@ -113,9 +111,13 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), - }) - n, err := dst.CopyOut(ctx, buf[:]) - return int64(n), err + } + n, err := infoNative.WriteTo(dst.Writer(ctx)) + if err == usermem.ErrEndOfIOSequence { + // Partial copy-out ok. + err = nil + } + return n, err } // Readiness implements waiter.Waitable.Readiness. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index a436610c9..f796e0fa3 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -917,7 +917,7 @@ func (t *Task) SetKcov(k *Kcov) { // ResetKcov clears the kcov instance associated with t. func (t *Task) ResetKcov() { if t.kcov != nil { - t.kcov.Reset() + t.kcov.OnTaskExit() t.kcov = nil } } diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go index 9fa528384..d1136461a 100644 --- a/pkg/sentry/kernel/task_context.go +++ b/pkg/sentry/kernel/task_context.go @@ -126,7 +126,11 @@ func (t *Task) SyscallTable() *SyscallTable { // Preconditions: The caller must be running on the task goroutine, or t.mu // must be locked. func (t *Task) Stack() *arch.Stack { - return &arch.Stack{t.Arch(), t.MemoryManager(), usermem.Addr(t.Arch().Stack())} + return &arch.Stack{ + Arch: t.Arch(), + IO: t.MemoryManager(), + Bottom: usermem.Addr(t.Arch().Stack()), + } } // LoadTaskImage loads a specified file into a new TaskContext. diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index feaa38596..ebdb83061 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -259,7 +259,11 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // Set up the signal handler. If we have a saved signal mask, the signal // handler should run with the current mask, but sigreturn should restore // the saved one. - st := &arch.Stack{t.Arch(), mm, sp} + st := &arch.Stack{ + Arch: t.Arch(), + IO: mm, + Bottom: sp, + } mask := t.signalMask if t.haveSavedSignalMask { mask = t.savedSignalMask diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index e44a139b3..9bc452e67 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -17,7 +17,6 @@ package kernel import ( "fmt" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -28,6 +27,8 @@ import ( // // They are exposed to the VDSO via a parameter page managed by VDSOParamPage, // which also includes a sequence counter. +// +// +marshal type vdsoParams struct { monotonicReady uint64 monotonicBaseCycles int64 @@ -68,6 +69,13 @@ type VDSOParamPage struct { // checked in state_test_util tests, causing this field to change across // save / restore. seq uint64 + + // copyScratchBuffer is a temporary buffer used to marshal the params before + // copying it to the real parameter page. The parameter page is typically + // updated at a moderate frequency of ~O(seconds) throughout the lifetime of + // the sentry, so reusing this buffer is a good tradeoff between memory + // usage and the cost of allocation. + copyScratchBuffer []byte } // NewVDSOParamPage returns a VDSOParamPage. @@ -79,7 +87,11 @@ type VDSOParamPage struct { // * VDSOParamPage must be the only writer to fr. // * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block. func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage { - return &VDSOParamPage{mfp: mfp, fr: fr} + return &VDSOParamPage{ + mfp: mfp, + fr: fr, + copyScratchBuffer: make([]byte, (*vdsoParams)(nil).SizeBytes()), + } } // access returns a mapping of the param page. @@ -133,7 +145,8 @@ func (v *VDSOParamPage) Write(f func() vdsoParams) error { // Get the new params. p := f() - buf := binary.Marshal(nil, usermem.ByteOrder, p) + buf := v.copyScratchBuffer[:p.SizeBytes()] + p.MarshalUnsafe(buf) // Skip the sequence counter. if _, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf)); err != nil { diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 15c88aa7c..c69b62db9 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -122,7 +122,7 @@ func allocStack(ctx context.Context, m *mm.MemoryManager, a arch.Context) (*arch if err != nil { return nil, err } - return &arch.Stack{a, m, ar.End}, nil + return &arch.Stack{Arch: a, IO: m, Bottom: ar.End}, nil } const ( @@ -247,20 +247,20 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V } // Push the original filename to the stack, for AT_EXECFN. - execfn, err := stack.Push(args.Filename) - if err != nil { + if _, err := stack.PushNullTerminatedByteSlice([]byte(args.Filename)); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push exec filename: %v", err), syserr.FromError(err).ToLinux()) } + execfn := stack.Bottom // Push 16 random bytes on the stack which AT_RANDOM will point to. var b [16]byte if _, err := rand.Read(b[:]); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to read random bytes: %v", err), syserr.FromError(err).ToLinux()) } - random, err := stack.Push(b) - if err != nil { + if _, err = stack.PushNullTerminatedByteSlice(b[:]); err != nil { return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to push random bytes: %v", err), syserr.FromError(err).ToLinux()) } + random := stack.Bottom c := auth.CredentialsFromContext(ctx) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 8c9f11cce..43567b92c 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -235,6 +235,14 @@ type MemoryManager struct { // vdsoSigReturnAddr is the address of 'vdso_sigreturn'. vdsoSigReturnAddr uint64 + + // membarrierPrivateEnabled is non-zero if EnableMembarrierPrivate has + // previously been called. Since, as of this writing, + // MEMBARRIER_CMD_PRIVATE_EXPEDITED is implemented as a global memory + // barrier, membarrierPrivateEnabled has no other effect. + // + // membarrierPrivateEnabled is accessed using atomic memory operations. + membarrierPrivateEnabled uint32 } // vma represents a virtual memory area. diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index a2555ba1a..0a66b1cdd 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -17,6 +17,7 @@ package mm import ( "fmt" mrand "math/rand" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -1274,3 +1275,15 @@ func (mm *MemoryManager) VirtualDataSize() uint64 { defer mm.mappingMu.RUnlock() return mm.dataAS } + +// EnableMembarrierPrivate causes future calls to IsMembarrierPrivateEnabled to +// return true. +func (mm *MemoryManager) EnableMembarrierPrivate() { + atomic.StoreUint32(&mm.membarrierPrivateEnabled, 1) +} + +// IsMembarrierPrivateEnabled returns true if mm.EnableMembarrierPrivate() has +// previously been called. +func (mm *MemoryManager) IsMembarrierPrivateEnabled() bool { + return atomic.LoadUint32(&mm.membarrierPrivateEnabled) != 0 +} diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 209b28053..db7d55ef2 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -15,6 +15,7 @@ go_library( "//pkg/context", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/hostmm", "//pkg/sentry/memmap", "//pkg/usermem", ], diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 3970dd81d..dd2bbeb12 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -9,12 +9,12 @@ go_library( "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", - "bluepill_amd64.s", "bluepill_amd64_unsafe.go", "bluepill_arm64.go", "bluepill_arm64.s", "bluepill_arm64_unsafe.go", "bluepill_fault.go", + "bluepill_impl_amd64.s", "bluepill_unsafe.go", "context.go", "filters_amd64.go", @@ -56,6 +56,7 @@ go_library( "//pkg/sentry/time", "//pkg/sync", "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", ], ) @@ -78,6 +79,15 @@ go_test( "//pkg/sentry/platform/kvm/testutil", "//pkg/sentry/platform/ring0", "//pkg/sentry/platform/ring0/pagetables", + "//pkg/sentry/time", "//pkg/usermem", ], ) + +genrule( + name = "bluepill_impl_amd64", + srcs = ["bluepill_amd64.s"], + outs = ["bluepill_impl_amd64.s"], + cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/sentry/platform/ring0/gen_offsets"], +) diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index 2bc34a435..025ea93b5 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -19,11 +19,6 @@ // This is guaranteed to be zero. #define VCPU_CPU 0x0 -// CPU_SELF is the self reference in ring0's percpu. -// -// This is guaranteed to be zero. -#define CPU_SELF 0x0 - // Context offsets. // // Only limited use of the context is done in the assembly stub below, most is @@ -44,7 +39,7 @@ begin: LEAQ VCPU_CPU(AX), BX BYTE CLI; check_vcpu: - MOVQ CPU_SELF(GS), CX + MOVQ ENTRY_CPU_SELF(GS), CX CMPQ BX, CX JE right_vCPU wrong_vcpu: diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 03a98512e..0a54dd30d 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -83,5 +83,34 @@ func bluepillStopGuest(c *vCPU) { // //go:nosplit func bluepillReadyStopGuest(c *vCPU) bool { - return c.runData.readyForInterruptInjection != 0 + if c.runData.readyForInterruptInjection == 0 { + return false + } + + if c.runData.ifFlag == 0 { + // This is impossible if readyForInterruptInjection is 1. + throw("interrupts are disabled") + } + + // Disable interrupts if we are in the kernel space. + // + // When the Sentry switches into the kernel mode, it disables + // interrupts. But when goruntime switches on a goroutine which has + // been saved in the host mode, it restores flags and this enables + // interrupts. See the comment of UserFlagsSet for more details. + uregs := userRegs{} + err := c.getUserRegisters(&uregs) + if err != 0 { + throw("failed to get user registers") + } + + if ring0.IsKernelFlags(uregs.RFLAGS) { + uregs.RFLAGS &^= ring0.KernelFlagsClear + err = c.setUserRegisters(&uregs) + if err != 0 { + throw("failed to set user registers") + } + return false + } + return true } diff --git a/pkg/sentry/platform/kvm/filters_amd64.go b/pkg/sentry/platform/kvm/filters_amd64.go index 7d949f1dd..d3d216aa5 100644 --- a/pkg/sentry/platform/kvm/filters_amd64.go +++ b/pkg/sentry/platform/kvm/filters_amd64.go @@ -17,14 +17,23 @@ package kvm import ( "syscall" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" ) // SyscallFilters returns syscalls made exclusively by the KVM platform. func (*KVM) SyscallFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - syscall.SYS_ARCH_PRCTL: {}, - syscall.SYS_IOCTL: {}, + syscall.SYS_ARCH_PRCTL: {}, + syscall.SYS_IOCTL: {}, + unix.SYS_MEMBARRIER: []seccomp.Rule{ + { + seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED), + seccomp.EqualTo(0), + }, + }, syscall.SYS_MMAP: {}, syscall.SYS_RT_SIGSUSPEND: {}, syscall.SYS_RT_SIGTIMEDWAIT: {}, diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go index 9245d07c2..21abc2a3d 100644 --- a/pkg/sentry/platform/kvm/filters_arm64.go +++ b/pkg/sentry/platform/kvm/filters_arm64.go @@ -17,13 +17,22 @@ package kvm import ( "syscall" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" ) // SyscallFilters returns syscalls made exclusively by the KVM platform. func (*KVM) SyscallFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - syscall.SYS_IOCTL: {}, + syscall.SYS_IOCTL: {}, + unix.SYS_MEMBARRIER: []seccomp.Rule{ + { + seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED), + seccomp.EqualTo(0), + }, + }, syscall.SYS_MMAP: {}, syscall.SYS_RT_SIGSUSPEND: {}, syscall.SYS_RT_SIGTIMEDWAIT: {}, diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index ae813e24e..dd45ad10b 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -63,6 +63,9 @@ type runData struct { type KVM struct { platform.NoCPUPreemptionDetection + // KVM never changes mm_structs. + platform.UseHostProcessMemoryBarrier + // machine is the backing VM. machine *machine } @@ -156,15 +159,7 @@ func (*KVM) MaxUserAddress() usermem.Addr { func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { // Allocate page tables and install system mappings. pageTables := pagetables.New(newAllocator()) - applyPhysicalRegions(func(pr physicalRegion) bool { - // Map the kernel in the upper half. - pageTables.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), - pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, - pr.physical) - return true // Keep iterating. - }) + k.machine.mapUpperHalf(pageTables) // Return the new address space. return &addressSpace{ diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 5c4b18899..6abaa21c4 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -26,12 +26,16 @@ const ( _KVM_RUN = 0xae80 _KVM_NMI = 0xae9a _KVM_CHECK_EXTENSION = 0xae03 + _KVM_GET_TSC_KHZ = 0xaea3 + _KVM_SET_TSC_KHZ = 0xaea2 _KVM_INTERRUPT = 0x4004ae86 _KVM_SET_MSRS = 0x4008ae89 _KVM_SET_USER_MEMORY_REGION = 0x4020ae46 _KVM_SET_REGS = 0x4090ae82 _KVM_SET_SREGS = 0x4138ae84 + _KVM_GET_MSRS = 0xc008ae88 _KVM_GET_REGS = 0x8090ae81 + _KVM_GET_SREGS = 0x8138ae83 _KVM_GET_SUPPORTED_CPUID = 0xc008ae05 _KVM_SET_CPUID2 = 0x4008ae90 _KVM_SET_SIGNAL_MASK = 0x4004ae8b @@ -79,11 +83,14 @@ const ( ) // KVM hypercall list. +// // Canonical list of hypercalls supported. const ( // On amd64, it uses 'HLT' to leave the guest. + // // Unlike amd64, arm64 can only uses mmio_exit/psci to leave the guest. - // _KVM_HYPERCALL_VMEXIT is only used on Arm64 for now. + // + // _KVM_HYPERCALL_VMEXIT is only used on arm64 for now. _KVM_HYPERCALL_VMEXIT int = iota _KVM_HYPERCALL_MAX ) diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 45b3180f1..2e12470aa 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) @@ -442,6 +443,22 @@ func TestWrongVCPU(t *testing.T) { }) } +func TestRdtsc(t *testing.T) { + var i int // Iteration count. + kvmTest(t, nil, func(c *vCPU) bool { + start := ktime.Rdtsc() + bluepill(c) + guest := ktime.Rdtsc() + redpill() + end := ktime.Rdtsc() + if start > guest || guest > end { + t.Errorf("inconsistent time: start=%d, guest=%d, end=%d", start, guest, end) + } + i++ + return i < 100 + }) +} + func BenchmarkApplicationSyscall(b *testing.B) { var ( i int // Iteration includes machine.Get() / machine.Put(). diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 372a4cbd7..75da253c5 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -155,7 +155,7 @@ func (m *machine) newVCPU() *vCPU { fd: int(fd), machine: m, } - c.CPU.Init(&m.kernel, c) + c.CPU.Init(&m.kernel, c.id, c) m.vCPUsByID[c.id] = c // Ensure the signal mask is correct. @@ -183,9 +183,6 @@ func newMachine(vm int) (*machine, error) { // Create the machine. m := &machine{fd: vm} m.available.L = &m.mu - m.kernel.Init(ring0.KernelOpts{ - PageTables: pagetables.New(newAllocator()), - }) // Pull the maximum vCPUs. maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) @@ -197,6 +194,9 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) + m.kernel.Init(ring0.KernelOpts{ + PageTables: pagetables.New(newAllocator()), + }, m.maxVCPUs) // Pull the maximum slots. maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) @@ -219,15 +219,9 @@ func newMachine(vm int) (*machine, error) { pagetables.MapOpts{AccessType: usermem.AnyAccess}, pr.physical) - // And keep everything in the upper half. - m.kernel.PageTables.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), - pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, - pr.physical) - return true // Keep iterating. }) + m.mapUpperHalf(m.kernel.PageTables) var physicalRegionsReadOnly []physicalRegion var physicalRegionsAvailable []physicalRegion @@ -365,6 +359,11 @@ func (m *machine) Destroy() { // Get gets an available vCPU. // // This will return with the OS thread locked. +// +// It is guaranteed that if any OS thread TID is in guest, m.vCPUs[TID] points +// to the vCPU in which the OS thread TID is running. So if Get() returns with +// the corrent context in guest, the vCPU of it must be the same as what +// Get() returns. func (m *machine) Get() *vCPU { m.mu.RLock() runtime.LockOSThread() diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index acc823ba6..451953008 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -18,14 +18,17 @@ package kvm import ( "fmt" + "math/big" "reflect" "runtime/debug" "syscall" + "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) @@ -144,6 +147,7 @@ func (c *vCPU) initArchState() error { // Set the entrypoint for the kernel. kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer()) kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer()) + kernelUserRegs.RSP = c.StackTop() kernelUserRegs.RFLAGS = ring0.KernelFlagsSet // Set the system registers. @@ -152,8 +156,8 @@ func (c *vCPU) initArchState() error { } // Set the user registers. - if err := c.setUserRegisters(&kernelUserRegs); err != nil { - return err + if errno := c.setUserRegisters(&kernelUserRegs); errno != 0 { + return fmt.Errorf("error setting user registers: %v", errno) } // Allocate some floating point state save area for the local vCPU. @@ -166,6 +170,133 @@ func (c *vCPU) initArchState() error { return c.setSystemTime() } +// bitsForScaling returns the bits available for storing the fraction component +// of the TSC scaling ratio. This allows us to replicate the (bad) math done by +// the kernel below in scaledTSC, and ensure we can compute an exact zero +// offset in setSystemTime. +// +// These constants correspond to kvm_tsc_scaling_ratio_frac_bits. +var bitsForScaling = func() int64 { + fs := cpuid.HostFeatureSet() + if fs.Intel() { + return 48 // See vmx.c (kvm sources). + } else if fs.AMD() { + return 32 // See svm.c (svm sources). + } else { + return 63 // Unknown: theoretical maximum. + } +}() + +// scaledTSC returns the host TSC scaled by the given frequency. +// +// This assumes a current frequency of 1. We require only the unitless ratio of +// rawFreq to some current frequency. See setSystemTime for context. +// +// The kernel math guarantees that all bits of the multiplication and division +// will be correctly preserved and applied. However, it is not possible to +// actually store the ratio correctly. So we need to use the same schema in +// order to calculate the scaled frequency and get the same result. +// +// We can assume that the current frequency is (1), so we are calculating a +// strict inverse of this value. This simplifies this function considerably. +// +// Roughly, the returned value "scaledTSC" will have: +// scaledTSC/hostTSC == 1/rawFreq +// +//go:nosplit +func scaledTSC(rawFreq uintptr) int64 { + scale := int64(1 << bitsForScaling) + ratio := big.NewInt(scale / int64(rawFreq)) + ratio.Mul(ratio, big.NewInt(int64(ktime.Rdtsc()))) + ratio.Div(ratio, big.NewInt(scale)) + return ratio.Int64() +} + +// setSystemTime sets the vCPU to the system time. +func (c *vCPU) setSystemTime() error { + // First, scale down the clock frequency to the lowest value allowed by + // the API itself. How low we can go depends on the underlying + // hardware, but it is typically ~1/2^48 for Intel, ~1/2^32 for AMD. + // Even the lower bound here will take a 4GHz frequency down to 1Hz, + // meaning that everything should be able to handle a Khz setting of 1 + // with bits to spare. + // + // Note that reducing the clock does not typically require special + // capabilities as it is emulated in KVM. We don't actually use this + // capability, but it means that this method should be robust to + // different hardware configurations. + rawFreq, err := c.getTSCFreq() + if err != nil { + return c.setSystemTimeLegacy() + } + if err := c.setTSCFreq(1); err != nil { + return c.setSystemTimeLegacy() + } + + // Always restore the original frequency. + defer func() { + if err := c.setTSCFreq(rawFreq); err != nil { + panic(err.Error()) + } + }() + + // Attempt to set the system time in this compressed world. The + // calculation for offset normally looks like: + // + // offset = target_tsc - kvm_scale_tsc(vcpu, rdtsc()); + // + // So as long as the kvm_scale_tsc component is constant before and + // after the call to set the TSC value (and it is passes as the + // target_tsc), we will compute an offset value of zero. + // + // This is effectively cheating to make our "setSystemTime" call so + // unbelievably, incredibly fast that we do it "instantly" and all the + // calculations result in an offset of zero. + lastTSC := scaledTSC(rawFreq) + for { + if err := c.setTSC(uint64(lastTSC)); err != nil { + return err + } + nextTSC := scaledTSC(rawFreq) + if lastTSC == nextTSC { + return nil + } + lastTSC = nextTSC // Try again. + } +} + +// setSystemTimeLegacy calibrates and sets an approximate system time. +func (c *vCPU) setSystemTimeLegacy() error { + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Try to set the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + start := uint64(ktime.Rdtsc()) + if err := c.setTSC(start + (minimum / 2)); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? + upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && current <= upperThreshold { + return nil + } + } +} + // nonCanonical generates a canonical address return. // //go:nosplit @@ -345,3 +476,41 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { func availableRegionsForSetMem() (phyRegions []physicalRegion) { return physicalRegions } + +var execRegions = func() (regions []region) { + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { + return + } + if vr.accessType.Execute { + regions = append(regions, vr.region) + } + }) + return +}() + +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + for _, r := range execRegions { + physical, length, ok := translateToPhysical(r.virtual) + if !ok || length < r.length { + panic("impossilbe translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|r.virtual), + r.length, + pagetables.MapOpts{AccessType: usermem.Execute}, + physical) + } + for start, end := range m.kernel.EntryRegions() { + regionLen := end - start + physical, length, ok := translateToPhysical(start) + if !ok || length < regionLen { + panic("impossible translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|start), + regionLen, + pagetables.MapOpts{AccessType: usermem.ReadWrite}, + physical) + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 290f035dd..b430f92c6 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -23,7 +23,6 @@ import ( "unsafe" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/time" ) // loadSegments copies the current segments. @@ -61,91 +60,63 @@ func (c *vCPU) setCPUID() error { return nil } -// setSystemTime sets the TSC for the vCPU. +// getTSCFreq gets the TSC frequency. // -// This has to make the call many times in order to minimize the intrinsic -// error in the offset. Unfortunately KVM does not expose a relative offset via -// the API, so this is an approximation. We do this via an iterative algorithm. -// This has the advantage that it can generally deal with highly variable -// system call times and should converge on the correct offset. -func (c *vCPU) setSystemTime() error { - const ( - _MSR_IA32_TSC = 0x00000010 - calibrateTries = 10 - ) - registers := modelControlRegisters{ - nmsrs: 1, - } - registers.entries[0] = modelControlRegister{ - index: _MSR_IA32_TSC, +// If mustSucceed is true, then this function panics on error. +func (c *vCPU) getTSCFreq() (uintptr, error) { + rawFreq, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_TSC_KHZ, + 0 /* ignored */) + if errno != 0 { + return 0, errno } - target := uint64(^uint32(0)) - for done := 0; done < calibrateTries; { - start := uint64(time.Rdtsc()) - registers.entries[0].data = start + target - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_MSRS, - uintptr(unsafe.Pointer(®isters))); errno != 0 { - return fmt.Errorf("error setting system time: %v", errno) - } - // See if this is our new minimum call time. Note that this - // serves two functions: one, we make sure that we are - // accurately predicting the offset we need to set. Second, we - // don't want to do the final set on a slow call, which could - // produce a really bad result. So we only count attempts - // within +/- 6.25% of our minimum as an attempt. - end := uint64(time.Rdtsc()) - if end < start { - continue // Totally bogus. - } - half := (end - start) / 2 - if half < target { - target = half - } - if (half - target) < target/8 { - done++ - } + return rawFreq, nil +} + +// setTSCFreq sets the TSC frequency. +func (c *vCPU) setTSCFreq(freq uintptr) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_TSC_KHZ, + freq /* khz */); errno != 0 { + return fmt.Errorf("error setting TSC frequency: %v", errno) } return nil } -// setSignalMask sets the vCPU signal mask. -// -// This must be called prior to running the vCPU. -func (c *vCPU) setSignalMask() error { - // The layout of this structure implies that it will not necessarily be - // the same layout chosen by the Go compiler. It gets fudged here. - var data struct { - length uint32 - mask1 uint32 - mask2 uint32 - _ uint32 +// setTSC sets the TSC value. +func (c *vCPU) setTSC(value uint64) error { + const _MSR_IA32_TSC = 0x00000010 + registers := modelControlRegisters{ + nmsrs: 1, } - data.length = 8 // Fixed sigset size. - data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) - data.mask2 = ^uint32(bounceSignalMask >> 32) + registers.entries[0].index = _MSR_IA32_TSC + registers.entries[0].data = value if _, _, errno := syscall.RawSyscall( syscall.SYS_IOCTL, uintptr(c.fd), - _KVM_SET_SIGNAL_MASK, - uintptr(unsafe.Pointer(&data))); errno != 0 { - return fmt.Errorf("error setting signal mask: %v", errno) + _KVM_SET_MSRS, + uintptr(unsafe.Pointer(®isters))); errno != 0 { + return fmt.Errorf("error setting tsc: %v", errno) } return nil } // setUserRegisters sets user registers in the vCPU. -func (c *vCPU) setUserRegisters(uregs *userRegs) error { +// +//go:nosplit +func (c *vCPU) setUserRegisters(uregs *userRegs) syscall.Errno { if _, _, errno := syscall.RawSyscall( syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_REGS, uintptr(unsafe.Pointer(uregs))); errno != 0 { - return fmt.Errorf("error setting user registers: %v", errno) + return errno } - return nil + return 0 } // getUserRegisters reloads user registers in the vCPU. @@ -175,3 +146,17 @@ func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { } return nil } + +// getSystemRegisters sets system registers. +// +//go:nosplit +func (c *vCPU) getSystemRegisters(sregs *systemRegs) syscall.Errno { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_SREGS, + uintptr(unsafe.Pointer(sregs))); errno != 0 { + return errno + } + return 0 +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 9db171af9..2df762991 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -19,6 +19,7 @@ package kvm import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) @@ -48,6 +49,18 @@ const ( poolPCIDs = 8 ) +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + applyPhysicalRegions(func(pr physicalRegion) bool { + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|pr.virtual), + pr.length, + pagetables.MapOpts{AccessType: usermem.AnyAccess}, + pr.physical) + + return true // Keep iterating. + }) +} + // Get all read-only physicalRegions. func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { var rdonlyRegions []region diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 537419657..a163f956d 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -191,42 +191,6 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { return nil } -// setCPUID sets the CPUID to be used by the guest. -func (c *vCPU) setCPUID() error { - return nil -} - -// setSystemTime sets the TSC for the vCPU. -func (c *vCPU) setSystemTime() error { - return nil -} - -// setSignalMask sets the vCPU signal mask. -// -// This must be called prior to running the vCPU. -func (c *vCPU) setSignalMask() error { - // The layout of this structure implies that it will not necessarily be - // the same layout chosen by the Go compiler. It gets fudged here. - var data struct { - length uint32 - mask1 uint32 - mask2 uint32 - _ uint32 - } - data.length = 8 // Fixed sigset size. - data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) - data.mask2 = ^uint32(bounceSignalMask >> 32) - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_SIGNAL_MASK, - uintptr(unsafe.Pointer(&data))); errno != 0 { - return fmt.Errorf("error setting signal mask: %v", errno) - } - - return nil -} - // SwitchToUser unpacks architectural-details. func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) { // Check for canonical addresses. diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 607c82156..1d6ca245a 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -143,3 +143,29 @@ func (c *vCPU) waitUntilNot(state uint32) { panic("futex wait error") } } + +// setSignalMask sets the vCPU signal mask. +// +// This must be called prior to running the vCPU. +func (c *vCPU) setSignalMask() error { + // The layout of this structure implies that it will not necessarily be + // the same layout chosen by the Go compiler. It gets fudged here. + var data struct { + length uint32 + mask1 uint32 + mask2 uint32 + _ uint32 + } + data.length = 8 // Fixed sigset size. + data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) + data.mask2 = ^uint32(bounceSignalMask >> 32) + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_SIGNAL_MASK, + uintptr(unsafe.Pointer(&data))); errno != 0 { + return fmt.Errorf("error setting signal mask: %v", errno) + } + + return nil +} diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index 530e779b0..dcfe839a7 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/hostmm" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/usermem" ) @@ -52,6 +53,10 @@ type Platform interface { // can reliably return ErrContextCPUPreempted. DetectsCPUPreemption() bool + // HaveGlobalMemoryBarrier returns true if the GlobalMemoryBarrier method + // is supported. + HaveGlobalMemoryBarrier() bool + // MapUnit returns the alignment used for optional mappings into this // platform's AddressSpaces. Higher values indicate lower per-page costs // for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates @@ -97,6 +102,15 @@ type Platform interface { // called. PreemptAllCPUs() error + // GlobalMemoryBarrier blocks until all threads running application code + // (via Context.Switch) and all task goroutines "have passed through a + // state where all memory accesses to user-space addresses match program + // order between entry to and return from [GlobalMemoryBarrier]", as for + // membarrier(2). + // + // Preconditions: HaveGlobalMemoryBarrier() == true. + GlobalMemoryBarrier() error + // SyscallFilters returns syscalls made exclusively by this platform. SyscallFilters() seccomp.SyscallRules } @@ -115,6 +129,43 @@ func (NoCPUPreemptionDetection) PreemptAllCPUs() error { panic("This platform does not support CPU preemption detection") } +// UseHostGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and +// Platform.GlobalMemoryBarrier by invoking equivalent functionality on the +// host. +type UseHostGlobalMemoryBarrier struct{} + +// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier. +func (UseHostGlobalMemoryBarrier) HaveGlobalMemoryBarrier() bool { + return hostmm.HaveGlobalMemoryBarrier() +} + +// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier. +func (UseHostGlobalMemoryBarrier) GlobalMemoryBarrier() error { + return hostmm.GlobalMemoryBarrier() +} + +// UseHostProcessMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and +// Platform.GlobalMemoryBarrier by invoking a process-local memory barrier. +// This is faster than UseHostGlobalMemoryBarrier, but is only appropriate for +// platforms for which application code executes while using the sentry's +// mm_struct. +type UseHostProcessMemoryBarrier struct{} + +// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier. +func (UseHostProcessMemoryBarrier) HaveGlobalMemoryBarrier() bool { + // Fall back to a global memory barrier if a process-local one isn't + // available. + return hostmm.HaveProcessMemoryBarrier() || hostmm.HaveGlobalMemoryBarrier() +} + +// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier. +func (UseHostProcessMemoryBarrier) GlobalMemoryBarrier() error { + if hostmm.HaveProcessMemoryBarrier() { + return hostmm.ProcessMemoryBarrier() + } + return hostmm.GlobalMemoryBarrier() +} + // MemoryManager represents an abstraction above the platform address space // which manages memory mappings and their contents. type MemoryManager interface { diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index b52d0fbd8..f56aa3b79 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -192,6 +192,7 @@ func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {} type PTrace struct { platform.MMapMinAddr platform.NoCPUPreemptionDetection + platform.UseHostGlobalMemoryBarrier } // New returns a new ptrace-based implementation of the platform interface. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 9c6c2cf5c..f617519fa 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -76,15 +76,42 @@ type KernelOpts struct { type KernelArchState struct { KernelOpts + // cpuEntries is array of kernelEntry for all cpus + cpuEntries []kernelEntry + // globalIDT is our set of interrupt gates. - globalIDT idt64 + globalIDT *idt64 } -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { +// kernelEntry contains minimal CPU-specific arch state +// that can be mapped at the upper of the address space. +// Malicious APP might steal info from it via CPU bugs. +type kernelEntry struct { // stack is the stack used for interrupts on this CPU. stack [256]byte + // scratch space for temporary usage. + scratch0 uint64 + scratch1 uint64 + + // stackTop is the top of the stack. + stackTop uint64 + + // cpuSelf is back reference to CPU. + cpuSelf *CPU + + // kernelCR3 is the cr3 used for sentry kernel. + kernelCR3 uintptr + + // gdt is the CPU's descriptor table. + gdt descriptorTable + + // tss is the CPU's task state. + tss TaskState64 +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { // errorCode is the error code from the last exception. errorCode uintptr @@ -97,11 +124,7 @@ type CPUArchState struct { // exception. errorType uintptr - // gdt is the CPU's descriptor table. - gdt descriptorTable - - // tss is the CPU's task state. - tss TaskState64 + *kernelEntry } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go index 7fa43c2f5..d87b1fd00 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.go +++ b/pkg/sentry/platform/ring0/entry_amd64.go @@ -36,12 +36,15 @@ func sysenter() // This must be called prior to sysret/iret. func swapgs() +// jumpToKernel jumps to the kernel version of the current RIP. +func jumpToKernel() + // sysret returns to userspace from a system call. // // The return code is the vector that interrupted execution. // // See stubs.go for a note regarding the frame size of this function. -func sysret(*CPU, *arch.Registers) Vector +func sysret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector // "iret is the cadillac of CPL switching." // @@ -50,7 +53,7 @@ func sysret(*CPU, *arch.Registers) Vector // iret is nearly identical to sysret, except an iret is used to fully restore // all user state. This must be called in cases where all registers need to be // restored. -func iret(*CPU, *arch.Registers) Vector +func iret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector // exception is the generic exception entry. // diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s index 02df38331..f59747df3 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.s +++ b/pkg/sentry/platform/ring0/entry_amd64.s @@ -63,6 +63,15 @@ MOVQ offset+PTRACE_RSI(reg), SI; \ MOVQ offset+PTRACE_RDI(reg), DI; +// WRITE_CR3() writes the given CR3 value. +// +// The code corresponds to: +// +// mov %rax, %cr3 +// +#define WRITE_CR3() \ + BYTE $0x0f; BYTE $0x22; BYTE $0xd8; + // SWAP_GS swaps the kernel GS (CPU). #define SWAP_GS() \ BYTE $0x0F; BYTE $0x01; BYTE $0xf8; @@ -75,15 +84,9 @@ #define SYSRET64() \ BYTE $0x48; BYTE $0x0f; BYTE $0x07; -// LOAD_KERNEL_ADDRESS loads a kernel address. -#define LOAD_KERNEL_ADDRESS(from, to) \ - MOVQ from, to; \ - ORQ ·KernelStartAddress(SB), to; - // LOAD_KERNEL_STACK loads the kernel stack. -#define LOAD_KERNEL_STACK(from) \ - LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \ - LEAQ CPU_STACK_TOP(SP), SP; +#define LOAD_KERNEL_STACK(entry) \ + MOVQ ENTRY_STACK_TOP(entry), SP; // See kernel.go. TEXT ·Halt(SB),NOSPLIT,$0 @@ -95,58 +98,93 @@ TEXT ·swapgs(SB),NOSPLIT,$0 SWAP_GS() RET +// jumpToKernel changes execution to the kernel address space. +// +// This works by changing the return value to the kernel version. +TEXT ·jumpToKernel(SB),NOSPLIT,$0 + MOVQ 0(SP), AX + ORQ ·KernelStartAddress(SB), AX // Future return value. + MOVQ AX, 0(SP) + RET + // See entry_amd64.go. TEXT ·sysret(SB),NOSPLIT,$0-24 - // Save original state. - LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) - LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + CALL ·jumpToKernel(SB) + // Save original state and stack. sysenter() or exception() + // from APP(gr3) will switch to this stack, set the return + // value (vector: 32(SP)) and then do RET, which will also + // automatically return to the lower half. + MOVQ cpu+0(FP), BX + MOVQ regs+8(FP), AX + MOVQ userCR3+16(FP), CX MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) + // save SP AX userCR3 on the kernel stack. + MOVQ CPU_ENTRY(BX), BX + LOAD_KERNEL_STACK(BX) + PUSHQ PTRACE_RSP(AX) + PUSHQ PTRACE_RAX(AX) + PUSHQ CX + // Restore user register state. REGISTERS_LOAD(AX, 0) MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET. MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET. - MOVQ PTRACE_RSP(AX), SP // Restore the stack directly. - MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + + // restore userCR3, AX, SP. + POPQ AX // Get userCR3. + WRITE_CR3() // Switch to userCR3. + POPQ AX // Restore AX. + POPQ SP // Restore SP. SYSRET64() // See entry_amd64.go. TEXT ·iret(SB),NOSPLIT,$0-24 - // Save original state. - LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) - LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + CALL ·jumpToKernel(SB) + // Save original state and stack. sysenter() or exception() + // from APP(gr3) will switch to this stack, set the return + // value (vector: 32(SP)) and then do RET, which will also + // automatically return to the lower half. + MOVQ cpu+0(FP), BX + MOVQ regs+8(FP), AX + MOVQ userCR3+16(FP), CX MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) // Build an IRET frame & restore state. + MOVQ CPU_ENTRY(BX), BX LOAD_KERNEL_STACK(BX) - MOVQ PTRACE_SS(AX), BX; PUSHQ BX - MOVQ PTRACE_RSP(AX), CX; PUSHQ CX - MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX - MOVQ PTRACE_CS(AX), DI; PUSHQ DI - MOVQ PTRACE_RIP(AX), SI; PUSHQ SI - REGISTERS_LOAD(AX, 0) // Restore most registers. - MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + PUSHQ PTRACE_SS(AX) + PUSHQ PTRACE_RSP(AX) + PUSHQ PTRACE_FLAGS(AX) + PUSHQ PTRACE_CS(AX) + PUSHQ PTRACE_RIP(AX) + PUSHQ PTRACE_RAX(AX) // Save AX on kernel stack. + PUSHQ CX // Save userCR3 on kernel stack. + REGISTERS_LOAD(AX, 0) // Restore most registers. + POPQ AX // Get userCR3. + WRITE_CR3() // Switch to userCR3. + POPQ AX // Restore AX. IRET() // See entry_amd64.go. TEXT ·resume(SB),NOSPLIT,$0 // See iret, above. - MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX - MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX - MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI - MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI - REGISTERS_LOAD(GS, CPU_REGISTERS) - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + PUSHQ CPU_REGISTERS+PTRACE_SS(AX) + PUSHQ CPU_REGISTERS+PTRACE_RSP(AX) + PUSHQ CPU_REGISTERS+PTRACE_FLAGS(AX) + PUSHQ CPU_REGISTERS+PTRACE_CS(AX) + PUSHQ CPU_REGISTERS+PTRACE_RIP(AX) + REGISTERS_LOAD(AX, CPU_REGISTERS) + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX IRET() // See entry_amd64.go. TEXT ·Start(SB),NOSPLIT,$0 - LOAD_KERNEL_STACK(AX) // Set the stack. PUSHQ $0x0 // Previous frame pointer. MOVQ SP, BP // Set frame pointer. PUSHQ AX // First argument (CPU). @@ -155,53 +193,60 @@ TEXT ·Start(SB),NOSPLIT,$0 // See entry_amd64.go. TEXT ·sysenter(SB),NOSPLIT,$0 - // Interrupts are always disabled while we're executing in kernel mode - // and always enabled while executing in user mode. Therefore, we can - // reliably look at the flags in R11 to determine where this syscall - // was from. - TESTL $_RFLAGS_IF, R11 + // _RFLAGS_IOPL0 is always set in the user mode and it is never set in + // the kernel mode. See the comment of UserFlagsSet for more details. + TESTL $_RFLAGS_IOPL0, R11 JZ kernel - user: SWAP_GS() - XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks. - XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs). + MOVQ AX, ENTRY_SCRATCH0(GS) // Save user AX on scratch. + MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX. + WRITE_CR3() // Switch to kernel cr3. + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs. REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value. - MOVQ BX, PTRACE_RAX(AX) // Save everything else. - MOVQ BX, PTRACE_ORIGRAX(AX) MOVQ CX, PTRACE_RIP(AX) MOVQ R11, PTRACE_FLAGS(AX) - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX) - MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. - MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. + MOVQ SP, PTRACE_RSP(AX) + MOVQ ENTRY_SCRATCH0(GS), CX // Load saved user AX value. + MOVQ CX, PTRACE_RAX(AX) // Save everything else. + MOVQ CX, PTRACE_ORIGRAX(AX) + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Get stacks. + MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. + MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user. // Return to the kernel, where the frame is: // - // vector (sp+24) + // vector (sp+32) + // userCR3 (sp+24) // regs (sp+16) // cpu (sp+8) // vcpu.Switch (sp+0) // - MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. - MOVQ $Syscall, 24(SP) // Output vector. + MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer. + MOVQ $Syscall, 32(SP) // Output vector. RET kernel: // We can't restore the original stack, but we can access the registers // in the CPU state directly. No need for temporary juggling. - MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) - MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) - REGISTERS_SAVE(GS, CPU_REGISTERS) - MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS) - MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS) - MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS) - MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. - MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + MOVQ AX, ENTRY_SCRATCH0(GS) + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + REGISTERS_SAVE(AX, CPU_REGISTERS) + MOVQ CX, CPU_REGISTERS+PTRACE_RIP(AX) + MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(AX) + MOVQ SP, CPU_REGISTERS+PTRACE_RSP(AX) + MOVQ ENTRY_SCRATCH0(GS), BX + MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX) + MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX) + MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. + MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. // Call the syscall trampoline. LOAD_KERNEL_STACK(GS) - MOVQ CPU_SELF(GS), AX // Load vCPU. PUSHQ AX // First argument (vCPU). CALL ·kernelSyscall(SB) // Call the trampoline. POPQ AX // Pop vCPU. @@ -230,16 +275,21 @@ TEXT ·exception(SB),NOSPLIT,$0 // ERROR_CODE (sp+8) // VECTOR (sp+0) // - TESTL $_RFLAGS_IF, 32(SP) + TESTL $_RFLAGS_IOPL0, 32(SP) JZ kernel user: SWAP_GS() ADDQ $-8, SP // Adjust for flags. MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ). - XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs. + PUSHQ AX // Save user AX on stack. + MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX. + WRITE_CR3() // Switch to kernel cr3. + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs. REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX. + POPQ BX // Restore original AX. MOVQ BX, PTRACE_RAX(AX) // Save it. MOVQ BX, PTRACE_ORIGRAX(AX) MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX) @@ -249,34 +299,36 @@ user: MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX) // Copy out and return. + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. MOVQ 0(SP), BX // Load vector. MOVQ 8(SP), CX // Load error code. - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version). - MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. - MOVQ CX, CPU_ERROR_CODE(GS) // Set error code. - MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. - MOVQ BX, 24(SP) // Output vector. + MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Original stack (kernel version). + MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer. + MOVQ CX, CPU_ERROR_CODE(AX) // Set error code. + MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user. + MOVQ BX, 32(SP) // Output vector. RET kernel: // As per above, we can save directly. - MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) - MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) - REGISTERS_SAVE(GS, CPU_REGISTERS) - MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS) - MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS) - MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS) + PUSHQ AX + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + REGISTERS_SAVE(AX, CPU_REGISTERS) + POPQ BX + MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX) + MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX) + MOVQ 16(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RIP(AX) + MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(AX) + MOVQ 40(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RSP(AX) // Set the error code and adjust the stack. - MOVQ 8(SP), AX // Load the error code. - MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU. - MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + MOVQ 8(SP), BX // Load the error code. + MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU. + MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. MOVQ 0(SP), BX // BX contains the vector. - ADDQ $48, SP // Drop the exception frame. // Call the exception trampoline. LOAD_KERNEL_STACK(GS) - MOVQ CPU_SELF(GS), AX // Load vCPU. PUSHQ BX // Second argument (vector). PUSHQ AX // First argument (vCPU). CALL ·kernelException(SB) // Call the trampoline. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index e173b6d4e..253f54e49 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -467,6 +467,14 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 MOVD PTRACE_PSTATE(RSV_REG_APP), R1 WORD $0xd5184001 //MSR R1, SPSR_EL1 + // need use kernel space address to excute below code, since + // after SWITCH_TO_APP_PAGETABLE the ASID is changed to app's + // ASID. + WORD $0x10000061 // ADR R1, do_exit_to_el0 + ORR $0xffff000000000000, R1, R1 + JMP (R1) + +do_exit_to_el0: // RSV_REG & RSV_REG_APP will be loaded at the end. REGISTERS_LOAD(RSV_REG_APP, 0) MOVD PTRACE_TLS(RSV_REG_APP), RSV_REG diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 549f3d228..9742308d8 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,7 +24,10 @@ go_binary( "defs_impl_arm64.go", "main.go", ], - visibility = ["//pkg/sentry/platform/ring0:__pkg__"], + visibility = [ + "//pkg/sentry/platform/kvm:__pkg__", + "//pkg/sentry/platform/ring0:__pkg__", + ], deps = [ "//pkg/cpuid", "//pkg/sentry/arch", diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 021693791..264be23d3 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -19,8 +19,8 @@ package ring0 // N.B. that constraints on KernelOpts must be satisfied. // //go:nosplit -func (k *Kernel) Init(opts KernelOpts) { - k.init(opts) +func (k *Kernel) Init(opts KernelOpts, maxCPUs int) { + k.init(opts, maxCPUs) } // Halt halts execution. @@ -49,6 +49,11 @@ func (defaultHooks) KernelException(Vector) { // kernelSyscall is a trampoline. // +// When in amd64, it is called with %rip on the upper half, so it can +// NOT access to any global data which is not mapped on upper and must +// call to function pointers or interfaces to switch to the lower half +// so that callee can access to global data. +// // +checkescape:hard,stack // //go:nosplit @@ -58,6 +63,11 @@ func kernelSyscall(c *CPU) { // kernelException is a trampoline. // +// When in amd64, it is called with %rip on the upper half, so it can +// NOT access to any global data which is not mapped on upper and must +// call to function pointers or interfaces to switch to the lower half +// so that callee can access to global data. +// // +checkescape:hard,stack // //go:nosplit @@ -68,10 +78,10 @@ func kernelException(c *CPU, vector Vector) { // Init initializes a new CPU. // // Init allows embedding in other objects. -func (c *CPU) Init(k *Kernel, hooks Hooks) { - c.self = c // Set self reference. - c.kernel = k // Set kernel reference. - c.init() // Perform architectural init. +func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) { + c.self = c // Set self reference. + c.kernel = k // Set kernel reference. + c.init(cpuID) // Perform architectural init. // Require hooks. if hooks != nil { diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index d37981dbf..3a9dff4cc 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -18,13 +18,42 @@ package ring0 import ( "encoding/binary" + "reflect" + + "gvisor.dev/gvisor/pkg/usermem" ) // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts) { +func (k *Kernel) init(opts KernelOpts, maxCPUs int) { // Save the root page tables. k.PageTables = opts.PageTables + entrySize := reflect.TypeOf(kernelEntry{}).Size() + var ( + entries []kernelEntry + padding = 1 + ) + for { + entries = make([]kernelEntry, maxCPUs+padding-1) + totalSize := entrySize * uintptr(maxCPUs+padding-1) + addr := reflect.ValueOf(&entries[0]).Pointer() + if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize { + // The runtime forces power-of-2 alignment for allocations, and we are therefore + // safe once the first address is aligned and the chunk is at least a full page. + break + } + padding = padding << 1 + } + k.cpuEntries = entries + + k.globalIDT = &idt64{} + if reflect.TypeOf(idt64{}).Size() != usermem.PageSize { + panic("Size of globalIDT should be PageSize") + } + if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 { + panic("Allocated globalIDT should be page aligned") + } + // Setup the IDT, which is uniform. for v, handler := range handlers { // Allow Breakpoint and Overflow to be called from all @@ -39,8 +68,26 @@ func (k *Kernel) init(opts KernelOpts) { } } +func (k *Kernel) EntryRegions() map[uintptr]uintptr { + regions := make(map[uintptr]uintptr) + + addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer() + size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries)) + end, _ := usermem.Addr(addr + size).RoundUp() + regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + + addr = reflect.ValueOf(k.globalIDT).Pointer() + size = reflect.TypeOf(idt64{}).Size() + end, _ = usermem.Addr(addr + size).RoundUp() + regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + + return regions +} + // init initializes architecture-specific state. -func (c *CPU) init() { +func (c *CPU) init(cpuID int) { + c.kernelEntry = &c.kernel.cpuEntries[cpuID] + c.cpuSelf = c // Null segment. c.gdt[0].setNull() @@ -65,6 +112,7 @@ func (c *CPU) init() { // Set the kernel stack pointer in the TSS (virtual address). stackAddr := c.StackTop() + c.stackTop = stackAddr c.tss.rsp0Lo = uint32(stackAddr) c.tss.rsp0Hi = uint32(stackAddr >> 32) c.tss.ist1Lo = uint32(stackAddr) @@ -183,7 +231,7 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID) - kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID) + c.kernelCR3 = uintptr(c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)) // Sanitize registers. regs := switchOpts.Registers @@ -197,15 +245,11 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point. - jumpToKernel() // Switch to upper half. - writeCR3(uintptr(userCR3)) // Change to user address space. if switchOpts.FullRestore { - vector = iret(c, regs) + vector = iret(c, regs, uintptr(userCR3)) } else { - vector = sysret(c, regs) + vector = sysret(c, regs, uintptr(userCR3)) } - writeCR3(uintptr(kernelCR3)) // Return to kernel address space. - jumpToUser() // Return to lower half. SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point. WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. return @@ -219,7 +263,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { //go:nosplit func start(c *CPU) { // Save per-cpu & FS segment. - WriteGS(kernelAddr(c)) + WriteGS(kernelAddr(c.kernelEntry)) WriteFS(uintptr(c.registers.Fs_base)) // Initialize floating point. diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 60106f1e0..b294ccc7c 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -25,13 +25,13 @@ func HaltAndResume() func HaltEl1SvcAndResume() // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts) { +func (k *Kernel) init(opts KernelOpts, maxCPUs int) { // Save the root page tables. k.PageTables = opts.PageTables } // init initializes architecture-specific state. -func (c *CPU) init() { +func (c *CPU) init(cpuID int) { // Set the kernel stack pointer(virtual address). c.registers.Sp = uint64(c.StackTop()) diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go index ca968a036..0ec5c3bc5 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.go +++ b/pkg/sentry/platform/ring0/lib_amd64.go @@ -61,21 +61,9 @@ func wrgsbase(addr uintptr) // wrgsmsr writes to the GS_BASE MSR. func wrgsmsr(addr uintptr) -// writeCR3 writes the CR3 value. -func writeCR3(phys uintptr) - -// readCR3 reads the current CR3 value. -func readCR3() uintptr - // readCR2 reads the current CR2 value. func readCR2() uintptr -// jumpToKernel jumps to the kernel version of the current RIP. -func jumpToKernel() - -// jumpToUser jumps to the user version of the current RIP. -func jumpToUser() - // fninit initializes the floating point unit. func fninit() diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s index 75d742750..2fe83568a 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.s +++ b/pkg/sentry/platform/ring0/lib_amd64.s @@ -127,53 +127,6 @@ TEXT ·wrgsmsr(SB),NOSPLIT,$0-8 BYTE $0x0f; BYTE $0x30; // WRMSR RET -// jumpToUser changes execution to the user address. -// -// This works by changing the return value to the user version. -TEXT ·jumpToUser(SB),NOSPLIT,$0 - MOVQ 0(SP), AX - MOVQ ·KernelStartAddress(SB), BX - NOTQ BX - ANDQ BX, SP // Switch the stack. - ANDQ BX, BP // Switch the frame pointer. - ANDQ BX, AX // Future return value. - MOVQ AX, 0(SP) - RET - -// jumpToKernel changes execution to the kernel address space. -// -// This works by changing the return value to the kernel version. -TEXT ·jumpToKernel(SB),NOSPLIT,$0 - MOVQ 0(SP), AX - MOVQ ·KernelStartAddress(SB), BX - ORQ BX, SP // Switch the stack. - ORQ BX, BP // Switch the frame pointer. - ORQ BX, AX // Future return value. - MOVQ AX, 0(SP) - RET - -// writeCR3 writes the given CR3 value. -// -// The code corresponds to: -// -// mov %rax, %cr3 -// -TEXT ·writeCR3(SB),NOSPLIT,$0-8 - MOVQ cr3+0(FP), AX - BYTE $0x0f; BYTE $0x22; BYTE $0xd8; - RET - -// readCR3 reads the current CR3 value. -// -// The code corresponds to: -// -// mov %cr3, %rax -// -TEXT ·readCR3(SB),NOSPLIT,$0-8 - BYTE $0x0f; BYTE $0x20; BYTE $0xd8; - MOVQ AX, ret+0(FP) - RET - // readCR2 reads the current CR2 value. // // The code corresponds to: diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go index b8ab120a0..290d94bd6 100644 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ b/pkg/sentry/platform/ring0/offsets_amd64.go @@ -30,14 +30,22 @@ func Emit(w io.Writer) { c := &CPU{} fmt.Fprintf(w, "\n// CPU offsets.\n") - fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer()) + + e := &kernelEntry{} + fmt.Fprintf(w, "\n// CPU entry offsets.\n") + fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_SCRATCH1 0x%02x\n", reflect.ValueOf(&e.scratch1).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) + fmt.Fprintf(w, "#define _RFLAGS_IOPL0 0x%02x\n", _RFLAGS_IOPL0) fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) fmt.Fprintf(w, "\n// Vectors.\n") diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go index 9da0ea685..34fbc1c35 100644 --- a/pkg/sentry/platform/ring0/x86.go +++ b/pkg/sentry/platform/ring0/x86.go @@ -39,7 +39,9 @@ const ( _RFLAGS_AC = 1 << 18 _RFLAGS_NT = 1 << 14 - _RFLAGS_IOPL = 3 << 12 + _RFLAGS_IOPL0 = 1 << 12 + _RFLAGS_IOPL1 = 1 << 13 + _RFLAGS_IOPL = _RFLAGS_IOPL0 | _RFLAGS_IOPL1 _RFLAGS_DF = 1 << 10 _RFLAGS_IF = 1 << 9 _RFLAGS_STEP = 1 << 8 @@ -67,15 +69,45 @@ const ( KernelFlagsSet = _RFLAGS_RESERVED // UserFlagsSet are always set in userspace. - UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF + // + // _RFLAGS_IOPL is a set of two bits and it shows the I/O privilege + // level. The Current Privilege Level (CPL) of the task must be less + // than or equal to the IOPL in order for the task or program to access + // I/O ports. + // + // Here, _RFLAGS_IOPL0 is used only to determine whether the task is + // running in the kernel or userspace mode. In the user mode, the CPL is + // always 3 and it doesn't matter what IOPL is set if it is bellow CPL. + // + // We need to have one bit which will be always different in user and + // kernel modes. And we have to remember that even though we have + // KernelFlagsClear, we still can see some of these flags in the kernel + // mode. This can happen when the goruntime switches on a goroutine + // which has been saved in the host mode. On restore, the popf + // instruction is used to restore flags and this means that all flags + // what the goroutine has in the host mode will be restored in the + // kernel mode. + // + // _RFLAGS_IOPL0 is never set in host and kernel modes and we always set + // it in the user mode. So if this flag is set, the task is running in + // the user mode and if it isn't set, the task is running in the kernel + // mode. + UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF | _RFLAGS_IOPL0 // KernelFlagsClear should always be clear in the kernel. KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT // UserFlagsClear are always cleared in userspace. - UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL + UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL1 ) +// IsKernelFlags returns true if rflags coresponds to the kernel mode. +// +// go:nosplit +func IsKernelFlags(rflags uint64) bool { + return rflags&_RFLAGS_IOPL0 == 0 +} + // Vector is an exception vector. type Vector uintptr @@ -104,7 +136,7 @@ const ( VirtualizationException SecurityException = 0x1e SyscallInt80 = 0x80 - _NR_INTERRUPTS = SyscallInt80 + 1 + _NR_INTERRUPTS = 0x100 ) // System call vectors. diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 87b077e68..163af329b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -78,6 +78,13 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in return vfsfd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *socketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 0336a32d8..549787955 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -37,7 +39,7 @@ type matchMaker interface { // name is the matcher name as stored in the xt_entry_match struct. name() string - // marshal converts from an stack.Matcher to an ABI struct. + // marshal converts from a stack.Matcher to an ABI struct. marshal(matcher stack.Matcher) []byte // unmarshal converts from the ABI matcher struct to an @@ -93,3 +95,71 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf } return matchMaker.unmarshal(buf, filter) } + +// targetMaker knows how to (un)marshal a target. Once registered, +// marshalTarget and unmarshalTarget can be used. +type targetMaker interface { + // id uniquely identifies the target. + id() stack.TargetID + + // marshal converts from a stack.Target to an ABI struct. + marshal(target stack.Target) []byte + + // unmarshal converts from the ABI matcher struct to a stack.Target. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) +} + +// targetMakers maps the TargetID of supported targets to the targetMaker that +// marshals and unmarshals it. It is immutable after package initialization. +var targetMakers = map[stack.TargetID]targetMaker{} + +func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) { + tid := stack.TargetID{ + Name: name, + NetworkProtocol: netProto, + Revision: rev, + } + if _, ok := targetMakers[tid]; !ok { + return 0, false + } + + // Return the highest supported revision unless rev is higher. + for _, other := range targetMakers { + otherID := other.id() + if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev { + rev = uint8(otherID.Revision) + } + } + return rev, true +} + +// registerTargetMaker should be called by target extensions to register them +// with the netfilter package. +func registerTargetMaker(tm targetMaker) { + if _, ok := targetMakers[tm.id()]; ok { + panic(fmt.Sprintf("multiple targets registered with name %q.", tm.id())) + } + targetMakers[tm.id()] = tm +} + +func marshalTarget(target stack.Target) []byte { + targetMaker, ok := targetMakers[target.ID()] + if !ok { + panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID())) + } + return targetMaker.marshal(target) +} + +func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) { + tid := stack.TargetID{ + Name: target.Name.String(), + NetworkProtocol: filter.NetworkProtocol(), + Revision: target.Revision, + } + targetMaker, ok := targetMakers[tid] + if !ok { + nflog("unsupported target with name %q", target.Name.String()) + return nil, syserr.ErrInvalidArgument + } + return targetMaker.unmarshal(buf, filter) +} diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index e4c55a100..b560fae0d 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -181,18 +181,23 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) return nil, syserr.ErrInvalidArgument } - target, err := parseTarget(filter, optVal[:targetSize]) - if err != nil { - nflog("failed to parse target: %v", err) - return nil, syserr.ErrInvalidArgument - } - optVal = optVal[targetSize:] - table.Rules = append(table.Rules, stack.Rule{ + rule := stack.Rule{ Filter: filter, - Target: target, Matchers: matchers, - }) + } + + { + target, err := parseTarget(filter, optVal[:targetSize], false /* ipv6 */) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, err + } + rule.Target = target + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, rule) offsets[offset] = int(entryIdx) offset += uint32(entry.NextOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 3b2c1becd..4253f7bf4 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -184,18 +184,23 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) return nil, syserr.ErrInvalidArgument } - target, err := parseTarget(filter, optVal[:targetSize]) - if err != nil { - nflog("failed to parse target: %v", err) - return nil, syserr.ErrInvalidArgument - } - optVal = optVal[targetSize:] - table.Rules = append(table.Rules, stack.Rule{ + rule := stack.Rule{ Filter: filter, - Target: target, Matchers: matchers, - }) + } + + { + target, err := parseTarget(filter, optVal[:targetSize], true /* ipv6 */) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, err + } + rule.Target = target + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, rule) offsets[offset] = int(entryIdx) offset += uint32(entry.NextOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 871ea80ee..904a12e38 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -146,10 +147,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { case stack.FilterTable: table = stack.EmptyFilterTable() case stack.NATTable: - if ipv6 { - nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)") - return syserr.ErrInvalidArgument - } table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) @@ -199,7 +196,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // Check the user chains. for ruleIdx, rule := range table.Rules { - if _, ok := rule.Target.(stack.UserChainTarget); !ok { + if _, ok := rule.Target.(*stack.UserChainTarget); !ok { continue } @@ -220,7 +217,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // Set each jump to point to the appropriate rule. Right now they hold byte // offsets. for ruleIdx, rule := range table.Rules { - jump, ok := rule.Target.(JumpTarget) + jump, ok := rule.Target.(*JumpTarget) if !ok { continue } @@ -311,7 +308,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool { return false } switch rule.Target.(type) { - case stack.AcceptTarget, stack.DropTarget: + case *stack.AcceptTarget, *stack.DropTarget: return true default: return false @@ -322,7 +319,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { if !validUnderflow(rule, ipv6) { return false } - _, ok := rule.Target.(stack.AcceptTarget) + _, ok := rule.Target.(*stack.AcceptTarget) return ok } @@ -341,3 +338,20 @@ func hookFromLinux(hook int) stack.Hook { } panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } + +// TargetRevision returns a linux.XTGetRevision for a given target. It sets +// Revision to the highest supported value, unless the provided revision number +// is larger. +func TargetRevision(t *kernel.Task, revPtr usermem.Addr, netProto tcpip.NetworkProtocolNumber) (linux.XTGetRevision, *syserr.Error) { + // Read in the target name and version. + var rev linux.XTGetRevision + if _, err := rev.CopyIn(t, revPtr); err != nil { + return linux.XTGetRevision{}, syserr.FromError(err) + } + maxSupported, ok := targetRevision(rev.Name.String(), netProto, rev.Revision) + if !ok { + return linux.XTGetRevision{}, syserr.ErrProtocolNotSupported + } + rev.Revision = maxSupported + return rev, nil +} diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 87e41abd8..0e14447fe 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,255 +15,357 @@ package netfilter import ( - "errors" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) -// errorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const errorTargetName = "ERROR" +func init() { + // Standard targets include ACCEPT, DROP, RETURN, and JUMP. + registerTargetMaker(&standardTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&standardTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) + + // Both user chains and actual errors are represented in iptables by + // error targets. + registerTargetMaker(&errorTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&errorTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) + + registerTargetMaker(&redirectTargetMaker{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&nfNATTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) +} -// redirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const redirectTargetName = "REDIRECT" +type standardTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} -func marshalTarget(target stack.Target) []byte { +func (sm *standardTargetMaker) id() stack.TargetID { + // Standard targets have the empty string as a name and no revisions. + return stack.TargetID{ + NetworkProtocol: sm.NetworkProtocol, + } +} +func (*standardTargetMaker) marshal(target stack.Target) []byte { + // Translate verdicts the same way as the iptables tool. + var verdict int32 switch tg := target.(type) { - case stack.AcceptTarget: - return marshalStandardTarget(stack.RuleAccept) - case stack.DropTarget: - return marshalStandardTarget(stack.RuleDrop) - case stack.ErrorTarget: - return marshalErrorTarget(errorTargetName) - case stack.UserChainTarget: - return marshalErrorTarget(tg.Name) - case stack.ReturnTarget: - return marshalStandardTarget(stack.RuleReturn) - case stack.RedirectTarget: - return marshalRedirectTarget(tg) - case JumpTarget: - return marshalJumpTarget(tg) + case *stack.AcceptTarget: + verdict = -linux.NF_ACCEPT - 1 + case *stack.DropTarget: + verdict = -linux.NF_DROP - 1 + case *stack.ReturnTarget: + verdict = linux.NF_RETURN + case *JumpTarget: + verdict = int32(tg.Offset) default: panic(fmt.Errorf("unknown target of type %T", target)) } -} - -func marshalStandardTarget(verdict stack.RuleVerdict) []byte { - nflog("convert to binary: marshalling standard target") // The target's name will be the empty string. - target := linux.XTStandardTarget{ + xt := linux.XTStandardTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTStandardTarget, }, - Verdict: translateFromStandardVerdict(verdict), + Verdict: verdict, } ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) + return binary.Marshal(ret, usermem.ByteOrder, xt) +} + +func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) != linux.SizeOfXTStandardTarget { + nflog("buf has wrong size for standard target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + var standardTarget linux.XTStandardTarget + buf = buf[:linux.SizeOfXTStandardTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) + + if standardTarget.Verdict < 0 { + // A Verdict < 0 indicates a non-jump verdict. + return translateToStandardTarget(standardTarget.Verdict, filter.NetworkProtocol()) + } + // A verdict >= 0 indicates a jump. + return &JumpTarget{ + Offset: uint32(standardTarget.Verdict), + NetworkProtocol: filter.NetworkProtocol(), + }, nil +} + +type errorTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (em *errorTargetMaker) id() stack.TargetID { + // Error targets have no revision. + return stack.TargetID{ + Name: stack.ErrorTargetName, + NetworkProtocol: em.NetworkProtocol, + } } -func marshalErrorTarget(errorName string) []byte { +func (*errorTargetMaker) marshal(target stack.Target) []byte { + var errorName string + switch tg := target.(type) { + case *stack.ErrorTarget: + errorName = stack.ErrorTargetName + case *stack.UserChainTarget: + errorName = tg.Name + default: + panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target)) + } + // This is an error target named error - target := linux.XTErrorTarget{ + xt := linux.XTErrorTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTErrorTarget, }, } - copy(target.Name[:], errorName) - copy(target.Target.Name[:], errorTargetName) + copy(xt.Name[:], errorName) + copy(xt.Target.Name[:], stack.ErrorTargetName) ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) + return binary.Marshal(ret, usermem.ByteOrder, xt) } -func marshalRedirectTarget(rt stack.RedirectTarget) []byte { +func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) != linux.SizeOfXTErrorTarget { + nflog("buf has insufficient size for error target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + var errorTarget linux.XTErrorTarget + buf = buf[:linux.SizeOfXTErrorTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + + // Error targets are used in 2 cases: + // * An actual error case. These rules have an error + // named stack.ErrorTargetName. The last entry of the table + // is usually an error case to catch any packets that + // somehow fall through every rule. + // * To mark the start of a user defined chain. These + // rules have an error with the name of the chain. + switch name := errorTarget.Name.String(); name { + case stack.ErrorTargetName: + return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil + default: + // User defined chain. + return &stack.UserChainTarget{ + Name: name, + NetworkProtocol: filter.NetworkProtocol(), + }, nil + } +} + +type redirectTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (rm *redirectTargetMaker) id() stack.TargetID { + return stack.TargetID{ + Name: stack.RedirectTargetName, + NetworkProtocol: rm.NetworkProtocol, + } +} + +func (*redirectTargetMaker) marshal(target stack.Target) []byte { + rt := target.(*stack.RedirectTarget) // This is a redirect target named redirect - target := linux.XTRedirectTarget{ + xt := linux.XTRedirectTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTRedirectTarget, }, } - copy(target.Target.Name[:], redirectTargetName) + copy(xt.Target.Name[:], stack.RedirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) - target.NfRange.RangeSize = 1 - if rt.RangeProtoSpecified { - target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + xt.NfRange.RangeSize = 1 + xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + xt.NfRange.RangeIPV4.MinPort = htons(rt.Port) + xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort + return binary.Marshal(ret, usermem.ByteOrder, xt) +} + +func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if len(buf) < linux.SizeOfXTRedirectTarget { + nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("redirectTargetMaker: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var redirectTarget linux.XTRedirectTarget + buf = buf[:linux.SizeOfXTRedirectTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + + // Copy linux.XTRedirectTarget to stack.RedirectTarget. + target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()} + + // RangeSize should be 1. + nfRange := redirectTarget.NfRange + if nfRange.RangeSize != 1 { + nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Check if the flags are valid. + // Also check if we need to map ports or IP. + // For now, redirect target only supports destination port change. + // Port range and IP range are not supported yet. + if nfRange.RangeIPV4.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("redirectTargetMaker: invalid range flags %d", nfRange.RangeIPV4.Flags) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. + if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + nflog("redirectTargetMaker: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument } - // Convert port from little endian to big endian. - port := make([]byte, 2) - binary.LittleEndian.PutUint16(port, rt.MinPort) - target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) - binary.LittleEndian.PutUint16(port, rt.MaxPort) - target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) - return binary.Marshal(ret, usermem.ByteOrder, target) + if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP { + nflog("redirectTargetMaker: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) + return nil, syserr.ErrInvalidArgument + } + + target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.Port = ntohs(nfRange.RangeIPV4.MinPort) + + return &target, nil } -func marshalJumpTarget(jt JumpTarget) []byte { - nflog("convert to binary: marshalling jump target") +type nfNATTarget struct { + Target linux.XTEntryTarget + Range linux.NFNATRange +} - // The target's name will be the empty string. - target := linux.XTStandardTarget{ +const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange + +type nfNATTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (rm *nfNATTargetMaker) id() stack.TargetID { + return stack.TargetID{ + Name: stack.RedirectTargetName, + NetworkProtocol: rm.NetworkProtocol, + } +} + +func (*nfNATTargetMaker) marshal(target stack.Target) []byte { + rt := target.(*stack.RedirectTarget) + nt := nfNATTarget{ Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, + TargetSize: nfNATMarhsalledSize, + }, + Range: linux.NFNATRange{ + Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, }, - // Verdict is overloaded by the ABI. When positive, it holds - // the jump offset from the start of the table. - Verdict: int32(jt.Offset), } + copy(nt.Target.Name[:], stack.RedirectTargetName) + copy(nt.Range.MinAddr[:], rt.Addr) + copy(nt.Range.MaxAddr[:], rt.Addr) - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) + nt.Range.MinProto = htons(rt.Port) + nt.Range.MaxProto = nt.Range.MinProto + + ret := make([]byte, 0, nfNATMarhsalledSize) + return binary.Marshal(ret, usermem.ByteOrder, nt) } -// translateFromStandardVerdict translates verdicts the same way as the iptables -// tool. -func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { - switch verdict { - case stack.RuleAccept: - return -linux.NF_ACCEPT - 1 - case stack.RuleDrop: - return -linux.NF_DROP - 1 - case stack.RuleReturn: - return linux.NF_RETURN - default: - // TODO(gvisor.dev/issue/170): Support Jump. - panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) +func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if size := nfNATMarhsalledSize; len(buf) < size { + nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) + return nil, syserr.ErrInvalidArgument } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("nfNATTargetMaker: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var natRange linux.NFNATRange + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize] + binary.Unmarshal(buf, usermem.ByteOrder, &natRange) + + // We don't support port or address ranges. + if natRange.MinAddr != natRange.MaxAddr { + nflog("nfNATTargetMaker: MinAddr and MaxAddr are different") + return nil, syserr.ErrInvalidArgument + } + if natRange.MinProto != natRange.MaxProto { + nflog("nfNATTargetMaker: MinProto and MaxProto are different") + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/3549): Check for other flags. + // For now, redirect target only supports destination change. + if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags) + return nil, syserr.ErrInvalidArgument + } + + target := stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Addr: tcpip.Address(natRange.MinAddr[:]), + Port: ntohs(natRange.MinProto), + } + + return &target, nil } // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32) (stack.Target, error) { +func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: - return stack.AcceptTarget{}, nil + return &stack.AcceptTarget{NetworkProtocol: netProto}, nil case -linux.NF_DROP - 1: - return stack.DropTarget{}, nil + return &stack.DropTarget{NetworkProtocol: netProto}, nil case -linux.NF_QUEUE - 1: - return nil, errors.New("unsupported iptables verdict QUEUE") + nflog("unsupported iptables verdict QUEUE") + return nil, syserr.ErrInvalidArgument case linux.NF_RETURN: - return stack.ReturnTarget{}, nil + return &stack.ReturnTarget{NetworkProtocol: netProto}, nil default: - return nil, fmt.Errorf("unknown iptables verdict %d", val) + nflog("unknown iptables verdict %d", val) + return nil, syserr.ErrInvalidArgument } } // parseTarget parses a target from optVal. optVal should contain only the // target. -func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { +func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.Target, *syserr.Error) { nflog("set entries: parsing target of size %d", len(optVal)) if len(optVal) < linux.SizeOfXTEntryTarget { - return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) + nflog("optVal has insufficient size for entry target %d", len(optVal)) + return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget buf := optVal[:linux.SizeOfXTEntryTarget] binary.Unmarshal(buf, usermem.ByteOrder, &target) - switch target.Name.String() { - case "": - // Standard target. - if len(optVal) != linux.SizeOfXTStandardTarget { - return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) - } - var standardTarget linux.XTStandardTarget - buf = optVal[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) - - if standardTarget.Verdict < 0 { - // A Verdict < 0 indicates a non-jump verdict. - return translateToStandardTarget(standardTarget.Verdict) - } - // A verdict >= 0 indicates a jump. - return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil - - case errorTargetName: - // Error target. - if len(optVal) != linux.SizeOfXTErrorTarget { - return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) - } - var errorTarget linux.XTErrorTarget - buf = optVal[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) - - // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named errorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. - // * To mark the start of a user defined chain. These - // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case errorTargetName: - nflog("set entries: error target") - return stack.ErrorTarget{}, nil - default: - // User defined chain. - nflog("set entries: user-defined target %q", name) - return stack.UserChainTarget{Name: name}, nil - } - - case redirectTargetName: - // Redirect target. - if len(optVal) < linux.SizeOfXTRedirectTarget { - return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) - } - - if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { - return nil, fmt.Errorf("netfilter.SetEntries: bad proto %d", p) - } - - var redirectTarget linux.XTRedirectTarget - buf = optVal[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) - - // Copy linux.XTRedirectTarget to stack.RedirectTarget. - var target stack.RedirectTarget - nfRange := redirectTarget.NfRange - - // RangeSize should be 1. - if nfRange.RangeSize != 1 { - return nil, fmt.Errorf("netfilter.SetEntries: bad rangesize %d", nfRange.RangeSize) - } - - // TODO(gvisor.dev/issue/170): Check if the flags are valid. - // Also check if we need to map ports or IP. - // For now, redirect target only supports destination port change. - // Port range and IP range are not supported yet. - if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid range flags %d", nfRange.RangeIPV4.Flags) - } - target.RangeProtoSpecified = true - - target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) - target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) - - // TODO(gvisor.dev/issue/170): Port range is not supported yet. - if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { - return nil, fmt.Errorf("netfilter.SetEntries: minport != maxport (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) - } - - // Convert port from big endian to little endian. - port := make([]byte, 2) - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) - target.MinPort = binary.LittleEndian.Uint16(port) - - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) - target.MaxPort = binary.LittleEndian.Uint16(port) - return target, nil - } - // Unknown target. - return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet", target.Name.String()) + return unmarshalTarget(target, filter, optVal) } // JumpTarget implements stack.Target. @@ -274,9 +376,31 @@ type JumpTarget struct { // RuleNum is the rule to jump to. RuleNum int + + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (jt *JumpTarget) ID() stack.TargetID { + return stack.TargetID{ + NetworkProtocol: jt.NetworkProtocol, + } } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } + +func ntohs(port uint16) uint16 { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, port) + return usermem.ByteOrder.Uint16(buf) +} + +func htons(port uint16) uint16 { + buf := make([]byte, 2) + usermem.ByteOrder.PutUint16(buf, port) + return binary.BigEndian.Uint16(buf) +} diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index a38d25da9..c83b23242 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -82,6 +82,13 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV return fd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 6fede181a..87e30d742 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -198,7 +198,6 @@ var Metrics = tcpip.Stats{ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), - InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } @@ -1513,8 +1512,17 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return &vP, nil case linux.IP6T_ORIGINAL_DST: - // TODO(gvisor.dev/issue/170): ip6tables. - return nil, syserr.ErrInvalidArgument + if outLen < int(binary.Size(linux.SockAddrInet6{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: if outLen < linux.SizeOfIPTGetinfo { @@ -1556,6 +1564,26 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } return &entries, nil + case linux.IP6T_SO_GET_REVISION_TARGET: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber) + if err != nil { + return nil, err + } + return &ret, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1719,6 +1747,26 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in } return &entries, nil + case linux.IPT_SO_GET_REVISION_TARGET: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + ret, err := netfilter.TargetRevision(t, outPtr, header.IPv4ProtocolNumber) + if err != nil { + return nil, err + } + return &ret, nil + default: emitUnimplementedEventIP(t, name) } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index c0212ad76..4c6791fff 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -79,6 +79,13 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu return vfsfd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index a89583dad..cc7408698 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -7,10 +7,21 @@ go_template_instance( name = "socket_refs", out = "socket_refs.go", package = "unix", - prefix = "socketOpsCommon", + prefix = "socketOperations", template = "//pkg/refs_vfs2:refs_template", types = { - "T": "socketOpsCommon", + "T": "SocketOperations", + }, +) + +go_template_instance( + name = "socket_vfs2_refs", + out = "socket_vfs2_refs.go", + package = "unix", + prefix = "socketVFS2", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "SocketVFS2", }, ) @@ -20,6 +31,7 @@ go_library( "device.go", "io.go", "socket_refs.go", + "socket_vfs2_refs.go", "unix.go", "unix_vfs2.go", ], diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 917055cea..f80011ce4 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -55,6 +55,7 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + socketOperationsRefs socketOpsCommon } @@ -84,11 +85,27 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty return fs.NewFile(ctx, d, flags, &s) } +// DecRef implements RefCounter.DecRef. +func (s *SocketOperations) DecRef(ctx context.Context) { + s.socketOperationsRefs.DecRef(func() { + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } + }) +} + +// Release implemements fs.FileOperations.Release. +func (s *SocketOperations) Release(ctx context.Context) { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef(ctx) +} + // socketOpsCommon contains the socket operations common to VFS1 and VFS2. // // +stateify savable type socketOpsCommon struct { - socketOpsCommonRefs socket.SendReceiveTimeout ep transport.Endpoint @@ -101,23 +118,6 @@ type socketOpsCommon struct { abstractNamespace *kernel.AbstractSocketNamespace } -// DecRef implements RefCounter.DecRef. -func (s *socketOpsCommon) DecRef(ctx context.Context) { - s.socketOpsCommonRefs.DecRef(func() { - s.ep.Close(ctx) - if s.abstractNamespace != nil { - s.abstractNamespace.Remove(s.abstractName, s) - } - }) -} - -// Release implemements fs.FileOperations.Release. -func (s *socketOpsCommon) Release(ctx context.Context) { - // Release only decrements a reference on s because s may be referenced in - // the abstract socket namespace. - s.DecRef(ctx) -} - func (s *socketOpsCommon) isPacket() bool { switch s.stype { case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 8b1abd922..3345124cc 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -45,6 +45,7 @@ type SocketVFS2 struct { vfs.DentryMetadataFileDescriptionImpl vfs.LockFD + socketVFS2Refs socketOpsCommon } @@ -91,6 +92,25 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 return vfsfd, nil } +// DecRef implements RefCounter.DecRef. +func (s *SocketVFS2) DecRef(ctx context.Context) { + s.socketVFS2Refs.DecRef(func() { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } + }) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef(ctx) +} + // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 52281ccc2..396744597 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -17,7 +17,6 @@ package strace import ( - "encoding/binary" "fmt" "strconv" "strings" @@ -294,7 +293,7 @@ func itimerval(t *kernel.Task, addr usermem.Addr) string { } interval := timeval(t, addr) - value := timeval(t, addr+usermem.Addr(binary.Size(linux.Timeval{}))) + value := timeval(t, addr+usermem.Addr((*linux.Timeval)(nil).SizeBytes())) return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value) } @@ -304,7 +303,7 @@ func itimerspec(t *kernel.Task, addr usermem.Addr) string { } interval := timespec(t, addr) - value := timespec(t, addr+usermem.Addr(binary.Size(linux.Timespec{}))) + value := timespec(t, addr+usermem.Addr((*linux.Timespec)(nil).SizeBytes())) return fmt.Sprintf("%#x {interval=%s, value=%s}", addr, interval, value) } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 75752b2e6..a2e441448 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -21,6 +21,7 @@ go_library( "sys_identity.go", "sys_inotify.go", "sys_lseek.go", + "sys_membarrier.go", "sys_mempolicy.go", "sys_mmap.go", "sys_mount.go", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 5f26697d2..9c9def7cd 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -376,7 +376,7 @@ var AMD64 = &kernel.SyscallTable{ 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), 322: syscalls.Supported("execveat", Execveat), 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) + 324: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil), 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), // Syscalls implemented after 325 are "backports" from versions @@ -527,8 +527,8 @@ var ARM64 = &kernel.SyscallTable{ 96: syscalls.Supported("set_tid_address", SetTidAddress), 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), - 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 99: syscalls.Supported("set_robust_list", SetRobustList), + 100: syscalls.Supported("get_robust_list", GetRobustList), 101: syscalls.Supported("nanosleep", Nanosleep), 102: syscalls.Supported("getitimer", Getitimer), 103: syscalls.Supported("setitimer", Setitimer), @@ -695,7 +695,7 @@ var ARM64 = &kernel.SyscallTable{ 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), 281: syscalls.Supported("execveat", Execveat), 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) + 283: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil), 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), // Syscalls after 284 are "backports" from versions of Linux after 4.4. diff --git a/pkg/sentry/syscalls/linux/sys_membarrier.go b/pkg/sentry/syscalls/linux/sys_membarrier.go new file mode 100644 index 000000000..288cd8512 --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_membarrier.go @@ -0,0 +1,70 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Membarrier implements syscall membarrier(2). +func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + cmd := args[0].Int() + flags := args[1].Int() + + p := t.Kernel().Platform + if !p.HaveGlobalMemoryBarrier() { + // Event for applications that want membarrier on a configuration that + // doesn't support them. + t.Kernel().EmitUnimplementedEvent(t) + return 0, nil, syserror.ENOSYS + } + + if flags != 0 { + return 0, nil, syserror.EINVAL + } + + switch cmd { + case linux.MEMBARRIER_CMD_QUERY: + const supportedCommands = linux.MEMBARRIER_CMD_GLOBAL | + linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED | + linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED | + linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED | + linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED + return supportedCommands, nil, nil + case linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED: + if !t.MemoryManager().IsMembarrierPrivateEnabled() { + return 0, nil, syserror.EPERM + } + fallthrough + case linux.MEMBARRIER_CMD_GLOBAL, linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED: + return 0, nil, p.GlobalMemoryBarrier() + case linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED: + // no-op + return 0, nil, nil + case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED: + t.MemoryManager().EnableMembarrierPrivate() + return 0, nil, nil + case linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE, linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE: + // We're aware of these, but they aren't implemented since no platform + // supports them yet. + t.Kernel().EmitUnimplementedEvent(t) + fallthrough + default: + return 0, nil, syserror.EINVAL + } +} diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 0df3bd449..c50fd97eb 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -163,6 +163,7 @@ func Override() { // Override ARM64. s = linux.ARM64 + s.Table[2] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}) s.Table[5] = syscalls.Supported("setxattr", SetXattr) s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr) @@ -200,6 +201,7 @@ func Override() { s.Table[44] = syscalls.Supported("fstatfs", Fstatfs) s.Table[45] = syscalls.Supported("truncate", Truncate) s.Table[46] = syscalls.Supported("ftruncate", Ftruncate) + s.Table[47] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil) s.Table[48] = syscalls.Supported("faccessat", Faccessat) s.Table[49] = syscalls.Supported("chdir", Chdir) s.Table[50] = syscalls.Supported("fchdir", Fchdir) @@ -221,12 +223,14 @@ func Override() { s.Table[68] = syscalls.Supported("pwrite64", Pwrite64) s.Table[69] = syscalls.Supported("preadv", Preadv) s.Table[70] = syscalls.Supported("pwritev", Pwritev) + s.Table[71] = syscalls.Supported("sendfile", Sendfile) s.Table[72] = syscalls.Supported("pselect", Pselect) s.Table[73] = syscalls.Supported("ppoll", Ppoll) s.Table[74] = syscalls.Supported("signalfd4", Signalfd4) s.Table[76] = syscalls.Supported("splice", Splice) s.Table[77] = syscalls.Supported("tee", Tee) s.Table[78] = syscalls.Supported("readlinkat", Readlinkat) + s.Table[79] = syscalls.Supported("newfstatat", Newfstatat) s.Table[80] = syscalls.Supported("fstat", Fstat) s.Table[81] = syscalls.Supported("sync", Sync) s.Table[82] = syscalls.Supported("fsync", Fsync) @@ -251,8 +255,10 @@ func Override() { s.Table[210] = syscalls.Supported("shutdown", Shutdown) s.Table[211] = syscalls.Supported("sendmsg", SendMsg) s.Table[212] = syscalls.Supported("recvmsg", RecvMsg) + s.Table[213] = syscalls.Supported("readahead", Readahead) s.Table[221] = syscalls.Supported("execve", Execve) s.Table[222] = syscalls.Supported("mmap", Mmap) + s.Table[223] = syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil) s.Table[242] = syscalls.Supported("accept4", Accept4) s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg) s.Table[267] = syscalls.Supported("syncfs", Syncfs) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 8093ca55c..c855608db 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -92,7 +92,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index ea0c5413d..8db70a700 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -84,8 +84,8 @@ type VectorisedView struct { size int } -// NewVectorisedView creates a new vectorised view from an already-allocated slice -// of View and sets its size. +// NewVectorisedView creates a new vectorised view from an already-allocated +// slice of View and sets its size. func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } @@ -170,8 +170,9 @@ func (vv *VectorisedView) CapLength(length int) { } // Clone returns a clone of this VectorisedView. -// If the buffer argument is large enough to contain all the Views of this VectorisedView, -// the method will avoid allocations and use the buffer to store the Views of the clone. +// If the buffer argument is large enough to contain all the Views of this +// VectorisedView, the method will avoid allocations and use the buffer to +// store the Views of the clone. func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } @@ -209,7 +210,8 @@ func (vv *VectorisedView) PullUp(count int) (View, bool) { return newFirst, true } -// Size returns the size in bytes of the entire content stored in the vectorised view. +// Size returns the size in bytes of the entire content stored in the +// vectorised view. func (vv *VectorisedView) Size() int { return vv.size } @@ -222,6 +224,12 @@ func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } + return vv.ToOwnedView() +} + +// ToOwnedView returns a single view containing the content of the vectorised +// view that vv does not own. +func (vv *VectorisedView) ToOwnedView() View { u := make([]byte, 0, vv.size) for _, v := range vv.views { u = append(u, v...) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 19627fa9b..d4d785cca 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -118,18 +118,82 @@ func TTL(ttl uint8) NetworkChecker { v = ip.HopLimit() } if v != ttl { - t.Fatalf("Bad TTL, got %v, want %v", v, ttl) + t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) + } + } +} + +// IPFullLength creates a checker for the full IP packet length. The +// expected size is checked against both the Total Length in the +// header and the number of bytes received. +func IPFullLength(packetLength uint16) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + var v uint16 + var l uint16 + switch ip := h[0].(type) { + case header.IPv4: + v = ip.TotalLength() + l = uint16(len(ip)) + case header.IPv6: + v = ip.PayloadLength() + header.IPv6FixedHeaderSize + l = uint16(len(ip)) + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) + } + if l != packetLength { + t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) + } + if v != packetLength { + t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) + } + } +} + +// IPv4HeaderLength creates a checker that checks the IPv4 Header length. +func IPv4HeaderLength(headerLength int) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + switch ip := h[0].(type) { + case header.IPv4: + if hl := ip.HeaderLength(); hl != uint8(headerLength) { + t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) + } + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) } } } // PayloadLen creates a checker that checks the payload length. -func PayloadLen(plen int) NetworkChecker { +func PayloadLen(payloadLength int) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - if l := len(h[0].Payload()); l != plen { - t.Errorf("Bad payload length, got %v, want %v", l, plen) + if l := len(h[0].Payload()); l != payloadLength { + t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) + } + } +} + +// IPv4Options returns a checker that checks the options in an IPv4 packet. +func IPv4Options(want []byte) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + ip, ok := h[0].(header.IPv4) + if !ok { + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) + } + options := ip.Options() + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(options) == 0 { + return + } + if diff := cmp.Diff(want, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) } } } @@ -139,11 +203,11 @@ func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - // We only do this of IPv4 for now. + // We only do this for IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got %v, want %v", v, offset) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) } } } @@ -154,11 +218,11 @@ func FragmentFlags(flags uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - // We only do this of IPv4 for now. + // We only do this for IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got %v, want %v", v, flags) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) } } } @@ -208,7 +272,7 @@ func TOS(tos uint8, label uint32) NetworkChecker { t.Helper() if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) + t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) } } } @@ -234,7 +298,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { t.Helper() if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } ipv6Frag := header.IPv6Fragment(h[0].Payload()) @@ -261,7 +325,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } // Verify the checksum. @@ -297,7 +361,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } udp := header.UDP(last.Payload()) @@ -316,7 +380,7 @@ func SrcPort(port uint16) TransportChecker { t.Helper() if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got %v, want %v", p, port) + t.Errorf("Bad source port, got = %d, want = %d", p, port) } } } @@ -327,7 +391,7 @@ func DstPort(port uint16) TransportChecker { t.Helper() if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got %v, want %v", p, port) + t.Errorf("Bad destination port, got = %d, want = %d", p, port) } } } @@ -359,7 +423,7 @@ func TCPSeqNum(seq uint32) TransportChecker { } if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got %v, want %v", s, seq) + t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) } } } @@ -375,7 +439,7 @@ func TCPAckNum(seq uint32) TransportChecker { } if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got %v, want %v", s, seq) + t.Errorf("Bad ack number, got = %d, want = %d", s, seq) } } } @@ -492,7 +556,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { case header.TCPOptionMSS: v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) if wantOpts.MSS != v { - t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) + t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) } foundMSS = true i += 4 @@ -502,7 +566,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { } v := int(opts[i+2]) if v != wantOpts.WS { - t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) + t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) } foundWS = true i += 3 @@ -551,7 +615,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { t.Error("TS option specified but the timestamp value is zero") } if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) + t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) } if wantOpts.SACKPermitted && !foundSACKPermitted { t.Errorf("SACKPermitted option not found. Options: %x", opts) @@ -589,7 +653,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) } if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) + t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = binary.BigEndian.Uint32(opts[i+6:]) @@ -609,19 +673,19 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } if wantTS != foundTS { - t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) + t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) } if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) + t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) } if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) + t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) } } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not -// contain any SACK blocks in the TCP options. +// TCPNoSACKBlockChecker creates a checker that verifies that the segment does +// not contain any SACK blocks in the TCP options. func TCPNoSACKBlockChecker() TransportChecker { return TCPSACKBlockChecker(nil) } @@ -679,7 +743,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) + t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) } } } @@ -695,8 +759,8 @@ func Payload(want []byte) TransportChecker { } } -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and -// potentially additional ICMPv4 header fields. +// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 +// and potentially additional ICMPv4 header fields. func ICMPv4(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -724,10 +788,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } @@ -739,10 +803,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. +func ICMPv4Ident(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Ident(); got != want { + t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. +func ICMPv4Seq(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Sequence(); got != want { + t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. +// This assumes that the payload exactly makes up the rest of the slice. +func ICMPv4Checksum() TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + heldChecksum := icmpv4.Checksum() + icmpv4.SetChecksum(0) + newChecksum := ^header.Checksum(icmpv4, 0) + icmpv4.SetChecksum(heldChecksum) + if heldChecksum != newChecksum { + t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) + } + } +} + +// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. +func ICMPv4Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + payload := icmpv4.Payload() + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } } } @@ -782,10 +912,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } @@ -797,10 +927,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) } } } diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index 1193f1d7d..f7a4fbde1 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -24,6 +24,26 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) +// NullClock implements a clock that never advances. +type NullClock struct{} + +var _ tcpip.Clock = (*NullClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (*NullClock) NowNanoseconds() int64 { + return 0 +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (*NullClock) NowMonotonic() int64 { + return 0 +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer { + return nil +} + // ManualClock implements tcpip.Clock and only advances manually with Advance // method. type ManualClock struct { diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c00bcadfb..504408878 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -126,15 +126,6 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } -// SetPointer sets the pointer field in a Parameter error packet. -// This is the first byte of the type specific data field. -func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } - -// SetTypeSpecific sets the full 32 bit type specific data field. -func (b ICMPv4) SetTypeSpecific(val uint32) { - binary.BigEndian.PutUint32(b[icmpv4PointerOffset:], val) -} - // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 4eb5abd79..6be31beeb 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -156,9 +156,14 @@ const ( // ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4. const ( + // ICMPv6ErroneousHeader indicates an erroneous header field was encountered. ICMPv6ErroneousHeader ICMPv6Code = 0 - ICMPv6UnknownHeader ICMPv6Code = 1 - ICMPv6UnknownOption ICMPv6Code = 2 + + // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered. + ICMPv6UnknownHeader ICMPv6Code = 1 + + // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered. + ICMPv6UnknownOption ICMPv6Code = 2 ) // ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use @@ -177,7 +182,12 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } -// SetTypeSpecific sets the full 32 bit type specific data field. +// TypeSpecific returns the type specific data field. +func (b ICMPv6) TypeSpecific() uint32 { + return binary.BigEndian.Uint32(b[icmpv6PointerOffset:]) +} + +// SetTypeSpecific sets the type specific data field. func (b ICMPv6) SetTypeSpecific(val uint32) { binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index b07d9991d..4c6e4be64 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,10 +16,29 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" ) +// RFC 971 defines the fields of the IPv4 header on page 11 using the following +// diagram: ("Figure 4") +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Identification |Flags| Fragment Offset | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Time to Live | Protocol | Header Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Source Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Destination Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options | Padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// const ( versIHL = 0 tos = 1 @@ -33,6 +52,7 @@ const ( checksum = 10 srcAddr = 12 dstAddr = 16 + options = 20 ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the @@ -76,7 +96,8 @@ type IPv4Fields struct { // IPv4 represents an ipv4 header stored in a byte array. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. -// Always call IsValid() to validate an instance of IPv4 before using other methods. +// Always call IsValid() to validate an instance of IPv4 before using other +// methods. type IPv4 []byte const ( @@ -151,13 +172,44 @@ func IPVersion(b []byte) int { if len(b) < versIHL+1 { return -1 } - return int(b[versIHL] >> 4) + return int(b[versIHL] >> ipVersionShift) } +// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits +// of the first byte, and is counted in multiples of 4 bytes. +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// (...) +// Version: 4 bits +// The Version field indicates the format of the internet header. This +// document describes version 4. +// +// IHL: 4 bits +// Internet Header Length is the length of the internet header in 32 +// bit words, and thus points to the beginning of the data. Note that +// the minimum value for a correct header is 5. +// +const ( + ipVersionShift = 4 + ipIHLMask = 0x0f + IPv4IHLStride = 4 +) + // HeaderLength returns the value of the "header length" field of the ipv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { - return (b[versIHL] & 0xf) * 4 + return (b[versIHL] & ipIHLMask) * IPv4IHLStride +} + +// SetHeaderLength sets the value of the "Internet Header Length" field. +func (b IPv4) SetHeaderLength(hdrLen uint8) { + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize)) + } + b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } // ID returns the value of the identifier field of the ipv4 header. @@ -211,6 +263,12 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// Options returns a a buffer holding the options. +func (b IPv4) Options() []byte { + hdrLen := b.HeaderLength() + return b[options:hdrLen:hdrLen] +} + // TransportProtocol implements Network.TransportProtocol. func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber { return tcpip.TransportProtocolNumber(b.Protocol()) @@ -236,6 +294,11 @@ func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } +// SetTTL sets the "Time to Live" field of the IPv4 header. +func (b IPv4) SetTTL(v byte) { + b[ttl] = v +} + // SetTotalLength sets the "total length" field of the ipv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) @@ -276,7 +339,7 @@ func (b IPv4) CalculateChecksum() uint16 { // Encode encodes all the fields of the ipv4 header. func (b IPv4) Encode(i *IPv4Fields) { - b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf) + b.SetHeaderLength(i.IHL) b[tos] = i.TOS b.SetTotalLength(i.TotalLength) binary.BigEndian.PutUint16(b[id:], i.ID) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 0761a1807..c5d8a3456 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -34,6 +34,9 @@ const ( hopLimit = 7 v6SrcAddr = 8 v6DstAddr = v6SrcAddr + IPv6AddressSize + + // IPv6FixedHeaderSize is the size of the fixed header. + IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -69,7 +72,7 @@ type IPv6 []byte const ( // IPv6MinimumSize is the minimum size of a valid IPv6 packet. - IPv6MinimumSize = 40 + IPv6MinimumSize = IPv6FixedHeaderSize // IPv6AddressSize is the size, in bytes, of an IPv6 address. IPv6AddressSize = 16 @@ -306,14 +309,21 @@ func IsV6UnicastAddress(addr tcpip.Address) bool { return addr[0] != 0xff } +const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" + // SolicitedNodeAddr computes the solicited-node multicast address. This is // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 // address. func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { - const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// IsSolicitedNodeAddr determines whether the address is a solicited-node +// multicast address. +func IsSolicitedNodeAddr(addr tcpip.Address) bool { + return solicitedNodeMulticastPrefix == addr[:len(addr)-3] +} + // EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 // from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. // diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 3499d8399..583c2c5d3 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator { // obtained before modification is no longer used. type IPv6OptionsExtHdrOptionsIterator struct { reader bytes.Reader + + // optionOffset is the number of bytes from the first byte of the + // options field to the beginning of the current option. + optionOffset uint32 + + // nextOptionOffset is the offset of the next option. + nextOptionOffset uint32 +} + +// OptionOffset returns the number of bytes parsed while processing the +// option field of the current Extension Header. +func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 { + return i.optionOffset } // IPv6OptionUnknownAction is the action that must be taken if the processing @@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {} // the options data, or an error occured. func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) { for { + i.optionOffset = i.nextOptionOffset temp, err := i.reader.ReadByte() if err != nil { // If we can't read the first byte of a new option, then we know the @@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // know the option does not have Length and Data fields. End processing of // the Pad1 option and continue processing the buffer as a new option. if id == ipv6Pad1ExtHdrOptionIdentifier { + i.nextOptionOffset = i.optionOffset + 1 continue } @@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF) } - // Special-case the variable length padding option to avoid a copy. - if id == ipv6PadNExtHdrOptionIdentifier { - // Do we have enough bytes in the reader for the PadN option? - if n := i.reader.Len(); n < int(length) { - // Reset the reader to effectively consume the remaining buffer. - i.reader.Reset(nil) - - // We return the same error as if we failed to read a non-padding option - // so consumers of this iterator don't need to differentiate between - // padding and non-padding options. - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) - } + // Do we have enough bytes in the reader for the next option? + if n := i.reader.Len(); n < int(length) { + // Reset the reader to effectively consume the remaining buffer. + i.reader.Reset(nil) + + // We return the same error as if we failed to read a non-padding option + // so consumers of this iterator don't need to differentiate between + // padding and non-padding options. + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) + } + + i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */ + switch id { + case ipv6PadNExtHdrOptionIdentifier: + // Special-case the variable length padding option to avoid a copy. if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil { panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } - - // End processing of the PadN option and continue processing the buffer as - // a new option. continue - } - - bytes := make([]byte, length) - if n, err := io.ReadFull(&i.reader, bytes); err != nil { - // io.ReadFull may return io.EOF if i.reader has been exhausted. We use - // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the - // Length field found in the option. - if err == io.EOF { - err = io.ErrUnexpectedEOF + default: + bytes := make([]byte, length) + if n, err := io.ReadFull(&i.reader, bytes); err != nil { + // io.ReadFull may return io.EOF if i.reader has been exhausted. We use + // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the + // Length field found in the option. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) } - - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) + return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } - - return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } } @@ -382,6 +396,29 @@ type IPv6PayloadIterator struct { // Indicates to the iterator that it should return the remaining payload as a // raw payload on the next call to Next. forceRaw bool + + // headerOffset is the offset of the beginning of the current extension + // header starting from the beginning of the fixed header. + headerOffset uint32 + + // parseOffset is the byte offset into the current extension header of the + // field we are currently examining. It can be added to the header offset + // if the absolute offset within the packet is required. + parseOffset uint32 + + // nextOffset is the offset of the next header. + nextOffset uint32 +} + +// HeaderOffset returns the offset to the start of the extension +// header most recently processed. +func (i IPv6PayloadIterator) HeaderOffset() uint32 { + return i.headerOffset +} + +// ParseOffset returns the number of bytes successfully parsed. +func (i IPv6PayloadIterator) ParseOffset() uint32 { + return i.headerOffset + i.parseOffset } // MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing @@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa nextHdrIdentifier: nextHdrIdentifier, payload: payload.Clone(nil), // We need a buffer of size 1 for calls to bufio.Reader.ReadByte. - reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + nextOffset: IPv6FixedHeaderSize, } } @@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { // Next is unable to return anything because the iterator has reached the end of // the payload, or an error occured. func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { + i.headerOffset = i.nextOffset + i.parseOffset = 0 // We could be forced to return i as a raw header when the previous header was // a fragment extension header as the data following the fragment extension // header may not be complete. @@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { return IPv6RoutingExtHdr(bytes), false, nil case IPv6FragmentExtHdrIdentifier: var data [6]byte - // We ignore the returned bytes becauase we know the fragment extension + // We ignore the returned bytes because we know the fragment extension // header specific data will fit in data. nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) if err != nil { @@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP if err != nil { return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) } + i.parseOffset++ var length uint8 length, err = i.reader.ReadByte() i.payload.TrimFront(1) + if err != nil { if fragmentHdr { return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) @@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP length = 0 } + // Make parseOffset point to the first byte of the Extension Header + // specific data. + i.parseOffset++ + + // length is in 8 byte chunks but doesn't include the first one. + // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement + // in section 4.8 for new extension headers at the top of page 24. + // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet + // units, not including the first 8 octets. + i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit) + bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded if bytes == nil { bytes = make([]byte, bytesLen) diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go index b5540bf66..17a49d4fa 100644 --- a/pkg/tcpip/header/ipversion_test.go +++ b/pkg/tcpip/header/ipversion_test.go @@ -22,7 +22,7 @@ import ( func TestIPv4(t *testing.T) { b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) + b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize}) const want = header.IPv4Version if v := header.IPVersion(b); v != want { diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 522135557..5ca75c834 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -139,6 +139,7 @@ traverseExtensions: // Returns true if the header was successfully parsed. func UDP(pkt *stack.PacketBuffer) bool { _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + pkt.TransportProtocolNumber = header.UDPProtocolNumber return ok } @@ -162,5 +163,6 @@ func TCP(pkt *stack.PacketBuffer) bool { } _, ok = pkt.TransportHeader().Consume(hdrLen) + pkt.TransportProtocolNumber = header.TCPProtocolNumber return ok } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 46083925c..59710352b 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -9,6 +9,7 @@ go_test( "ip_test.go", ], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -17,6 +18,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", ], diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index b025bb087..b47a7be51 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -18,6 +18,8 @@ package arp import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -33,15 +35,57 @@ const ( ProtocolAddress = tcpip.Address("arp") ) -// endpoint implements stack.NetworkEndpoint. +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - protocol *protocol - nicID tcpip.NICID + stack.AddressableEndpointState + + protocol *protocol + + // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + nic stack.NetworkInterface linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler } +func (e *endpoint) Enable() *tcpip.Error { + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + e.setEnabled(true) + return nil +} + +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() +} + +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 +} + +// setEnabled sets the enabled status for the endpoint. +func (e *endpoint) setEnabled(v bool) { + if v { + atomic.StoreUint32(&e.enabled, 1) + } else { + atomic.StoreUint32(&e.enabled, 0) + } +} + +func (e *endpoint) Disable() { + e.setEnabled(false) +} + // DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. func (e *endpoint) DefaultTTL() uint8 { return 0 @@ -52,19 +96,13 @@ func (e *endpoint) MTU() uint32 { return lmtu - uint32(e.MaxHeaderLength()) } -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID -} - -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() {} +func (e *endpoint) Close() { + e.AddressableEndpointState.Cleanup() +} func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported @@ -85,6 +123,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { + return + } + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return @@ -95,15 +137,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { localAddr := tcpip.Address(h.ProtocolAddressTarget()) if e.nud == nil { - if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { - if r.Stack().CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } @@ -129,7 +171,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) return } @@ -161,14 +203,16 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { - return &endpoint{ +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ protocol: p, - nicID: nicID, - linkEP: sender, + nic: nic, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, } + e.AddressableEndpointState.Init(e) + return e } // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 96c5f42f8..47fb63290 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -29,6 +29,8 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", ], ) @@ -43,5 +45,8 @@ go_test( library = ":fragmentation", deps = [ "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", + "//pkg/tcpip/network/testutil", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 6a4843f92..888ad62a3 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -13,7 +13,7 @@ // limitations under the License. // Package fragmentation contains the implementation of IP fragmentation. -// It is based on RFC 791 and RFC 815. +// It is based on RFC 791, RFC 815 and RFC 8200. package fragmentation import ( @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -81,6 +82,8 @@ type Fragmentation struct { size int timeout time.Duration blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job } // NewFragmentation creates a new Fragmentation. @@ -97,7 +100,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -110,13 +113,17 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea blockSize = minBlockSize } - return &Fragmentation{ + f := &Fragmentation{ reassemblers: make(map[FragmentID]*reassembler), highLimit: highMemoryLimit, lowLimit: lowMemoryLimit, timeout: reassemblingTimeout, blockSize: blockSize, + clock: clock, } + f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) + + return f } // Process processes an incoming fragment belonging to an ID and returns a @@ -155,15 +162,17 @@ func (f *Fragmentation) Process( f.mu.Lock() r, ok := f.reassemblers[id] - if ok && r.tooOld(f.timeout) { - // This is very likely to be an id-collision or someone performing a slow-rate attack. - f.release(r) - ok = false - } if !ok { - r = newReassembler(id) + r = newReassembler(id, f.clock) f.reassemblers[id] = r + wasEmpty := f.rList.Empty() f.rList.PushFront(r) + if wasEmpty { + // If we have just pushed a first reassembler into an empty list, we + // should kickstart the release job. The release job will keep + // rescheduling itself until the list becomes empty. + f.releaseReassemblersLocked() + } } f.mu.Unlock() @@ -211,3 +220,102 @@ func (f *Fragmentation) release(r *reassembler) { f.size = 0 } } + +// releaseReassemblersLocked releases already-expired reassemblers, then +// schedules the job to call back itself for the remaining reassemblers if +// any. This function must be called with f.mu locked. +func (f *Fragmentation) releaseReassemblersLocked() { + now := f.clock.NowMonotonic() + for { + // The reassembler at the end of the list is the oldest. + r := f.rList.Back() + if r == nil { + // The list is empty. + break + } + elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + if f.timeout > elapsed { + // If the oldest reassembler has not expired, schedule the release + // job so that this function is called back when it has expired. + f.releaseJob.Schedule(f.timeout - elapsed) + break + } + // If the oldest reassembler has already expired, release it. + f.release(r) + } +} + +// PacketFragmenter is the book-keeping struct for packet fragmentation. +type PacketFragmenter struct { + transportHeader buffer.View + data buffer.VectorisedView + reserve int + innerMTU int + fragmentCount int + currentFragment int + fragmentOffset int +} + +// MakePacketFragmenter prepares the struct needed for packet fragmentation. +// +// pkt is the packet to be fragmented. +// +// innerMTU is the maximum number of bytes of fragmentable data a fragment can +// have. +// +// reserve is the number of bytes that should be reserved for the headers in +// each generated fragment. +func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter { + // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be + // repeated in each fragment. However we do not currently support any header + // of that kind yet, so the following computation is valid for both IPv4 and + // IPv6. + // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are + // supported for outbound packets, the fragmentable data should not include + // these headers. + var fragmentableData buffer.VectorisedView + fragmentableData.AppendView(pkt.TransportHeader().View()) + fragmentableData.Append(pkt.Data) + fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU + + return PacketFragmenter{ + data: fragmentableData, + reserve: reserve, + innerMTU: innerMTU, + fragmentCount: fragmentCount, + } +} + +// BuildNextFragment returns a packet with the payload of the next fragment, +// along with the fragment's offset, the number of bytes copied and a boolean +// indicating if there are more fragments left or not. If this function is +// called again after it indicated that no more fragments were left, it will +// panic. +// +// Note that the returned packet will not have its network and link headers +// populated, but space for them will be reserved. The transport header will be +// stored in the packet's data. +func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) { + if pf.currentFragment >= pf.fragmentCount { + panic("BuildNextFragment should not be called again after the last fragment was returned") + } + + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: pf.reserve, + }) + + // Copy data for the fragment. + copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU) + + offset := pf.fragmentOffset + pf.fragmentOffset += copied + pf.currentFragment++ + more := pf.currentFragment != pf.fragmentCount + + return fragPkt, offset, copied, more +} + +// RemainingFragmentCount returns the number of fragments left to be built. +func (pf *PacketFragmenter) RemainingFragmentCount() int { + return pf.fragmentCount - pf.currentFragment +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 416604659..31a1eb862 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -20,7 +20,10 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" ) // vv is a helper to build VectorisedView from different strings. @@ -95,7 +98,7 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{}) firstFragmentProto := c.in[0].proto for i, in := range c.in { vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) @@ -131,25 +134,126 @@ func TestFragmentationProcess(t *testing.T) { } func TestReassemblingTimeout(t *testing.T) { - timeout := time.Millisecond - f := NewFragmentation(minBlockSize, 1024, 512, timeout) - // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) - // Sleep more than the timeout. - time.Sleep(2 * timeout) - // Send another fragment that completes a packet. - // However, no packet should be reassembled because the fragment arrived after the timeout. - _, _, done, err := f.Process(FragmentID{}, 1, 1, false, 0xFF, vv(1, "1")) - if err != nil { - t.Fatalf("f.Process(0, 1, 1, false, 0xFF, vv(1, \"1\")) failed: %v", err) + const ( + reassemblyTimeout = time.Millisecond + protocol = 0xff + ) + + type fragment struct { + first uint16 + last uint16 + more bool + data string } - if done { - t.Errorf("Fragmentation does not respect the reassembling timeout.") + + type event struct { + // name is a nickname of this event. + name string + + // clockAdvance is a duration to advance the clock. The clock advances + // before a fragment specified in the fragment field is processed. + clockAdvance time.Duration + + // fragment is a fragment to process. This can be nil if there is no + // fragment to process. + fragment *fragment + + // expectDone is true if the fragmentation instance should report the + // reassembly is done after the fragment is processd. + expectDone bool + + // sizeAfterEvent is the expected size of the fragmentation instance after + // the event. + sizeAfterEvent int + } + + half1 := &fragment{first: 0, last: 0, more: true, data: "0"} + half2 := &fragment{first: 1, last: 1, more: false, data: "1"} + + tests := []struct { + name string + events []event + }{ + { + name: "half1 and half2 are reassembled successfully", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half2", + fragment: half2, + expectDone: true, + sizeAfterEvent: 0, + }, + }, + }, + { + name: "half1 timeout, half2 timeout", + events: []event{ + { + name: "half1", + fragment: half1, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half1 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + sizeAfterEvent: 1, + }, + { + name: "half1 reassembly timeout", + clockAdvance: 1, + sizeAfterEvent: 0, + }, + { + name: "half2", + fragment: half2, + expectDone: false, + sizeAfterEvent: 1, + }, + { + name: "half2 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + sizeAfterEvent: 1, + }, + { + name: "half2 reassembly timeout", + clockAdvance: 1, + sizeAfterEvent: 0, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock) + for _, event := range test.events { + clock.Advance(event.clockAdvance) + if frag := event.fragment; frag != nil { + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data)) + if err != nil { + t.Fatalf("%s: f.Process failed: %s", event.name, err) + } + if done != event.expectDone { + t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) + } + } + if got, want := f.size, event.sizeAfterEvent; got != want { + t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want) + } + } + }) } } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) // Send first fragment with id = 1. @@ -173,7 +277,7 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) // Send the same packet again. @@ -268,7 +372,7 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout) + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{}) _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) @@ -279,3 +383,113 @@ func TestErrors(t *testing.T) { }) } } + +type fragmentInfo struct { + remaining int + copied int + offset int + more bool +} + +func TestPacketFragmenter(t *testing.T) { + const ( + reserve = 60 + proto = 0 + ) + + tests := []struct { + name string + innerMTU int + transportHeaderLen int + payloadSize int + wantFragments []fragmentInfo + }{ + { + name: "Packet exactly fits in MTU", + innerMTU: 1280, + transportHeaderLen: 0, + payloadSize: 1280, + wantFragments: []fragmentInfo{ + {remaining: 0, copied: 1280, offset: 0, more: false}, + }, + }, + { + name: "Packet exactly does not fit in MTU", + innerMTU: 1000, + transportHeaderLen: 0, + payloadSize: 1001, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 1000, offset: 0, more: true}, + {remaining: 0, copied: 1, offset: 1000, more: false}, + }, + }, + { + name: "Packet has a transport header", + innerMTU: 560, + transportHeaderLen: 40, + payloadSize: 560, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 560, offset: 0, more: true}, + {remaining: 0, copied: 40, offset: 560, more: false}, + }, + }, + { + name: "Packet has a huge transport header", + innerMTU: 500, + transportHeaderLen: 1300, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {remaining: 3, copied: 500, offset: 0, more: true}, + {remaining: 2, copied: 500, offset: 500, more: true}, + {remaining: 1, copied: 500, offset: 1000, more: true}, + {remaining: 0, copied: 300, offset: 1500, more: false}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) + var originalPayload buffer.VectorisedView + originalPayload.AppendView(pkt.TransportHeader().View()) + originalPayload.Append(pkt.Data) + var reassembledPayload buffer.VectorisedView + pf := MakePacketFragmenter(pkt, test.innerMTU, reserve) + for i := 0; ; i++ { + fragPkt, offset, copied, more := pf.BuildNextFragment() + wantFragment := test.wantFragments[i] + if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { + t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) + } + if copied != wantFragment.copied { + t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) + } + if offset != wantFragment.offset { + t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) + } + if more != wantFragment.more { + t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) + } + if got := fragPkt.Size(); got > test.innerMTU { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU) + } + if got := fragPkt.AvailableHeaderBytes(); got != reserve { + t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) + } + if got := fragPkt.TransportHeader().View().Size(); got != 0 { + t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) + } + reassembledPayload.Append(fragPkt.Data) + if !more { + if i != len(test.wantFragments)-1 { + t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) + } + break + } + } + if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { + t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index f044867dc..9bb051a30 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "time" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -40,15 +40,15 @@ type reassembler struct { deleted int heap fragHeap done bool - creationTime time.Time + creationTime int64 } -func newReassembler(id FragmentID) *reassembler { +func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, holes: make([]hole, 0, 16), heap: make(fragHeap, 0, 8), - creationTime: time.Now(), + creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, @@ -116,10 +116,6 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf return res, r.proto, true, consumed, nil } -func (r *reassembler) tooOld(timeout time.Duration) bool { - return time.Now().Sub(r.creationTime) > timeout -} - func (r *reassembler) checkDoneOrMark() bool { r.mu.Lock() prev := r.done diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index dff7c9dcb..a0a04a027 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -18,6 +18,8 @@ import ( "math" "reflect" "testing" + + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) type updateHolesInput struct { @@ -94,7 +96,7 @@ var holesTestCases = []struct { func TestUpdateHoles(t *testing.T) { for _, c := range holesTestCases { - r := newReassembler(FragmentID{}) + r := newReassembler(FragmentID{}, &faketime.NullClock{}) for _, i := range c.in { r.updateHoles(i.first, i.last, i.more) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4640ca95c..6861cfdaf 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -17,6 +17,7 @@ package ip_test import ( "testing" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -25,26 +26,35 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( - localIpv4Addr = "\x0a\x00\x00\x01" - localIpv4PrefixLen = 24 - remoteIpv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - localIpv6PrefixLen = 120 - remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" - nicID = 1 + localIPv4Addr = "\x0a\x00\x00\x01" + remoteIPv4Addr = "\x0a\x00\x00\x02" + ipv4SubnetAddr = "\x0a\x00\x00\x00" + ipv4SubnetMask = "\xff\xff\xff\x00" + ipv4Gateway = "\x0a\x00\x00\x03" + localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" + ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + nicID = 1 ) +var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: localIPv4Addr, + PrefixLen: 24, +} + +var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: localIPv6Addr, + PrefixLen: 120, +} + // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. // The former is used to pretend that it's a link endpoint so that we can // inspect packets written by the network endpoints. The latter is used to @@ -225,7 +235,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } -func buildDummyStack(t *testing.T) *stack.Stack { +func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) { t.Helper() s := stack.New(stack.Options{ @@ -237,22 +247,282 @@ func buildDummyStack(t *testing.T) *stack.Stack { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err) + v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix} + if err := s.AddProtocolAddress(nicID, v4Addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err) + v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix} + if err := s.AddProtocolAddress(nicID, v6Addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err) } + return s, e +} + +func buildDummyStack(t *testing.T) *stack.Stack { + t.Helper() + + s, _ := buildDummyStackWithLinkEndpoint(t) return s } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + tester testObject + + mu struct { + sync.RWMutex + disabled bool + } +} + +func (*testInterface) ID() tcpip.NICID { + return nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (t *testInterface) Enabled() bool { + t.mu.RLock() + defer t.mu.RUnlock() + return !t.mu.disabled +} + +func (t *testInterface) setEnabled(v bool) { + t.mu.Lock() + defer t.mu.Unlock() + t.mu.disabled = !v +} + +func (t *testInterface) LinkEndpoint() stack.LinkEndpoint { + return &t.tester +} + +func TestSourceAddressValidation(t *testing.T) { + rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv4Addr, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + tests := []struct { + name string + srcAddress tcpip.Address + rxICMP func(*channel.Endpoint, tcpip.Address) + valid bool + }{ + { + name: "IPv4 valid", + srcAddress: "\x01\x02\x03\x04", + rxICMP: rxIPv4ICMP, + valid: true, + }, + { + name: "IPv6 valid", + srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10", + rxICMP: rxIPv6ICMP, + valid: true, + }, + { + name: "IPv4 unspecified", + srcAddress: header.IPv4Any, + rxICMP: rxIPv4ICMP, + valid: true, + }, + { + name: "IPv6 unspecified", + srcAddress: header.IPv4Any, + rxICMP: rxIPv6ICMP, + valid: true, + }, + { + name: "IPv4 multicast", + srcAddress: "\xe0\x00\x00\x01", + rxICMP: rxIPv4ICMP, + valid: false, + }, + { + name: "IPv6 multicast", + srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + rxICMP: rxIPv6ICMP, + valid: false, + }, + { + name: "IPv4 broadcast", + srcAddress: header.IPv4Broadcast, + rxICMP: rxIPv4ICMP, + valid: false, + }, + { + name: "IPv4 subnet broadcast", + srcAddress: func() tcpip.Address { + subnet := localIPv4AddrWithPrefix.Subnet() + return subnet.Broadcast() + }(), + rxICMP: rxIPv4ICMP, + valid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, e := buildDummyStackWithLinkEndpoint(t) + test.rxICMP(e, test.srcAddress) + + var wantValid uint64 + if test.valid { + wantValid = 1 + } + + if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want { + t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid { + t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid) + } + }) + } +} + +func TestEnableWhenNICDisabled(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + }{ + { + name: "IPv4", + protocolFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + }, + { + name: "IPv6", + protocolFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var nic testInterface + nic.setEnabled(false) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory}, + }) + p := s.NetworkProtocolInstance(test.protoNum) + + // We pass nil for all parameters except the NetworkInterface and Stack + // since Enable only depends on these. + ep := p.NewEndpoint(&nic, nil, nil, nil) + + // The endpoint should initially be disabled, regardless the NIC's enabled + // status. + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + nic.setEnabled(true) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Attempting to enable the endpoint while the NIC is disabled should + // fail. + nic.setEnabled(false) + if err := ep.Enable(); err != tcpip.ErrNotPermitted { + t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted) + } + // ep should consider the NIC's enabled status when determining its own + // enabled status so we "enable" the NIC to read just the endpoint's + // enabled status. + nic.setEnabled(true) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Enabling the interface after the NIC has been enabled should succeed. + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + if !ep.Enabled() { + t.Fatal("got ep.Enabled() = false, want = true") + } + + // ep should consider the NIC's enabled status when determining its own + // enabled status. + nic.setEnabled(false) + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + + // Disabling the endpoint when the NIC is enabled should make the endpoint + // disabled. + nic.setEnabled(true) + ep.Disable() + if ep.Enabled() { + t.Fatal("got ep.Enabled() = true, want = false") + } + }) + } +} + func TestIPv4Send(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) defer ep.Close() // Allocate and initialize the payload view. @@ -268,12 +538,12 @@ func TestIPv4Send(t *testing.T) { }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv4Addr - o.dstAddr = remoteIpv4Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIPv4Addr + nic.tester.dstAddr = remoteIPv4Addr + nic.tester.contents = payload - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -287,12 +557,21 @@ func TestIPv4Send(t *testing.T) { } func TestIPv4Receive(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + totalLen := header.IPv4MinimumSize + 30 view := buffer.NewView(totalLen) ip := header.IPv4(view) @@ -301,8 +580,8 @@ func TestIPv4Receive(t *testing.T) { TotalLength: uint16(totalLen), TTL: 20, Protocol: 10, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. @@ -311,12 +590,12 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[header.IPv4MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = view[header.IPv4MinimumSize:totalLen] - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -327,8 +606,8 @@ func TestIPv4Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -352,18 +631,26 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") + r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") if err != nil { t.Fatal(err) } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize view := buffer.NewView(dataOffset + 8) @@ -375,7 +662,7 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: uint8(header.ICMPv4ProtocolNumber), SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: localIpv4Addr, + DstAddr: localIPv4Addr, }) // Create the ICMP header. @@ -393,8 +680,8 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: c.fragmentOffset, - SrcAddr: localIpv4Addr, - DstAddr: remoteIpv4Addr, + SrcAddr: localIPv4Addr, + DstAddr: remoteIPv4Addr, }) // Make payload be non-zero. @@ -404,28 +691,37 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } } func TestIPv4FragmentationReceive(t *testing.T) { - o := testObject{t: t, v4: true} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + v4: true, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + totalLen := header.IPv4MinimumSize + 24 frag1 := buffer.NewView(totalLen) @@ -437,8 +733,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { Protocol: 10, FragmentOffset: 0, Flags: header.IPv4FlagMoreFragments, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -453,8 +749,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: 24, - SrcAddr: remoteIpv4Addr, - DstAddr: localIpv4Addr, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, }) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -462,12 +758,12 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv4Addr - o.dstAddr = localIpv4Addr - o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv4Addr + nic.tester.dstAddr = localIPv4Addr + nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -480,8 +776,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + if nic.tester.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls) } // Send second segment. @@ -492,18 +788,26 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } func TestIPv6Send(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, nil) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + // Allocate and initialize the payload view. payload := buffer.NewView(100) for i := 0; i < len(payload); i++ { @@ -517,12 +821,12 @@ func TestIPv6Send(t *testing.T) { }) // Issue the write. - o.protocol = 123 - o.srcAddr = localIpv6Addr - o.dstAddr = remoteIpv6Addr - o.contents = payload + nic.tester.protocol = 123 + nic.tester.srcAddr = localIPv6Addr + nic.tester.dstAddr = remoteIPv6Addr + nic.tester.contents = payload - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -536,12 +840,20 @@ func TestIPv6Send(t *testing.T) { } func TestIPv6Receive(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + totalLen := header.IPv6MinimumSize + 30 view := buffer.NewView(totalLen) ip := header.IPv6(view) @@ -549,8 +861,8 @@ func TestIPv6Receive(t *testing.T) { PayloadLength: uint16(totalLen - header.IPv6MinimumSize), NextHeader: 10, HopLimit: 20, - SrcAddr: remoteIpv6Addr, - DstAddr: localIpv6Addr, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, }) // Make payload be non-zero. @@ -559,12 +871,12 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[header.IPv6MinimumSize:totalLen] + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv6Addr + nic.tester.dstAddr = localIPv6Addr + nic.tester.contents = view[header.IPv6MinimumSize:totalLen] - r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { t.Fatalf("could not find route: %v", err) } @@ -576,8 +888,8 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if o.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + if nic.tester.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) } } @@ -608,7 +920,7 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } r, err := buildIPv6Route( - localIpv6Addr, + localIPv6Addr, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", ) if err != nil { @@ -616,12 +928,20 @@ func TestIPv6ReceiveControl(t *testing.T) { } for _, c := range cases { t.Run(c.name, func(t *testing.T) { - o := testObject{t: t} s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s) + nic := testInterface{ + tester: testObject{ + t: t, + }, + } + ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize @@ -635,7 +955,7 @@ func TestIPv6ReceiveControl(t *testing.T) { NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, SrcAddr: outerSrcAddr, - DstAddr: localIpv6Addr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -651,8 +971,8 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: 100, NextHeader: 10, HopLimit: 20, - SrcAddr: localIpv6Addr, - DstAddr: remoteIpv6Addr, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, }) // Build the fragmentation header if needed. @@ -674,19 +994,19 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - o.protocol = 10 - o.srcAddr = remoteIpv6Addr - o.dstAddr = localIpv6Addr - o.contents = view[dataOffset:] - o.typ = c.expectedTyp - o.extra = c.expectedExtra + nic.tester.protocol = 10 + nic.tester.srcAddr = remoteIPv6Addr + nic.tester.dstAddr = localIPv6Addr + nic.tester.contents = view[dataOffset:] + nic.tester.typ = c.expectedTyp + nic.tester.extra = c.expectedExtra // Set ICMPv6 checksum. - icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) - if want := c.expectedCount; o.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + if want := c.expectedCount; nic.tester.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) } }) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index f9c2aa980..ee2c23e91 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -27,12 +28,14 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3e5cf2ad9..eab9a530c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,8 @@ package ipv4 import ( + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -40,7 +42,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. src := hdr.SourceAddress() - if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -77,27 +79,27 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Echo.Increment() // Only send a reply if the checksum is valid. - wantChecksum := h.Checksum() - // Reset the checksum field to 0 to can calculate the proper - // checksum. We'll have to reset this before we hand the packet - // off. + headerChecksum := h.Checksum() h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - if gotChecksum != wantChecksum { - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) + calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(headerChecksum) + if calculatedChecksum != headerChecksum { + // It's possible that a raw socket still expects to receive this. e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } - // Make a copy of data before pkt gets sent to raw socket. - // DeliverTransportPacket will take ownership of pkt. - replyData := pkt.Data.Clone(nil) - replyData.TrimFront(header.ICMPv4MinimumSize) + // DeliverTransportPacket will take ownership of pkt so don't use it beyond + // this point. Make a deep copy of the data before pkt gets sent as we will + // be modifying fields. + // + // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no + // waiting endpoints. Consider moving responsibility for doing the copy to + // DeliverTransportPacket so that is is only done when needed. + replyData := pkt.Data.ToOwnedView() + replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) remoteLinkAddr := r.RemoteLinkAddress @@ -110,7 +112,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { localAddr = "" } - r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -120,29 +122,49 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // Use the remote link address from the incoming packet. r.ResolveWith(remoteLinkAddr) - // Prepare a reply packet. - icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize) - copy(icmpHdr, h) - icmpHdr.SetType(header.ICMPv4EchoReply) - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0))) - dataVV := buffer.View(icmpHdr).ToVectorisedView() - dataVV.Append(replyData) + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the + // header information, we may have to change this code to handle the + // ICMP header no longer being in the data buffer. + + // Because IP and ICMP are so closely intertwined, we need to handcraft our + // IP header to be able to follow RFC 792. The wording on page 13 is as + // follows: + // IP Fields: + // Addresses + // The address of the source in an echo message will be the + // destination of the echo reply message. To form an echo reply + // message, the source and destination addresses are simply reversed, + // the type code changed to 0, and the checksum recomputed. + // + // This was interpreted by early implementors to mean that all options must + // be copied from the echo request IP header to the echo reply IP header + // and this behaviour is still relied upon by some applications. + // + // Create a copy of the IP header we received, options and all, and change + // The fields we need to alter. + // + // We need to produce the entire packet in the data segment in order to + // use WriteHeaderIncludedPacket(). + replyIPHdr.SetSourceAddress(r.LocalAddress) + replyIPHdr.SetDestinationAddress(r.RemoteAddress) + replyIPHdr.SetTTL(r.DefaultTTL()) + + replyICMPHdr := header.ICMPv4(replyData) + replyICMPHdr.SetType(header.ICMPv4EchoReply) + replyICMPHdr.SetChecksum(0) + replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0)) + + replyVV := buffer.View(replyIPHdr).ToVectorisedView() + replyVV.AppendView(replyData) replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: dataVV, + Data: replyVV, }) - // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header - // information we will have to change this code to handle the ICMP header - // no longer being in the data buffer. replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - // Send out the reply packet. + + // The checksum will be calculated so we don't need to do it here. sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ - Protocol: header.ICMPv4ProtocolNumber, - TTL: r.DefaultTTL(), - TOS: stack.DefaultTOS, - }, replyPkt); err != nil { + if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return } @@ -211,6 +233,12 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonProtoUnreachable is an error where the transport protocol is +// not supported. +type icmpReasonProtoUnreachable struct{} + +func (*icmpReasonProtoUnreachable) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -287,8 +315,6 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // Assume any type we don't know about may be an error type. return nil } - } else if transportHeader.IsEmpty() { - return nil } // Now work out how much of the triggering packet we should return. @@ -336,13 +362,21 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc ReserveHeaderBytes: headerLen, Data: payload, }) + icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + switch reason.(type) { + case *icmpReasonPortUnreachable: + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + case *icmpReasonProtoUnreachable: + icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter := sent.DstUnreachable icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + counter := sent.DstUnreachable if err := r.WritePacket( nil, /* gso */ diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 254d66147..79c939129 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -19,6 +19,7 @@ import ( "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -47,22 +48,115 @@ const ( fragmentblockSize = 8 ) +var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() + +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID + nic stack.NetworkInterface linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol - stack *stack.Stack + + // enabled is set to 1 when the enpoint is enabled and 0 when it is + // disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + mu struct { + sync.RWMutex + + addressableEndpointState stack.AddressableEndpointState + } } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { - return &endpoint{ - nicID: nicID, - linkEP: linkEP, +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ + nic: nic, + linkEP: nic.LinkEndpoint(), dispatcher: dispatcher, protocol: p, - stack: st, + } + e.mu.addressableEndpointState.Init(e) + return e +} + +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If the NIC is not enabled, the endpoint can't do anything meaningful so + // don't enable the endpoint. + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + // If the endpoint is already enabled, there is nothing for it to do. + if !e.setEnabled(true) { + return nil + } + + // Create an endpoint to receive broadcast packets on this interface. + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + if err != nil { + return err + } + // We have no need for the address endpoint. + ep.DecRef() + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) + return err +} + +// Enabled implements stack.NetworkEndpoint. +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() +} + +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 +} + +// setEnabled sets the enabled status for the endpoint. +// +// Returns true if the enabled status was updated. +func (e *endpoint) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&e.enabled, 1) == 0 + } + return atomic.SwapUint32(&e.enabled, 0) == 1 +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() { + e.mu.Lock() + defer e.mu.Unlock() + e.disableLocked() +} + +func (e *endpoint) disableLocked() { + if !e.setEnabled(false) { + return + } + + // The endpoint may have already left the multicast group. + if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) + } + + // The address may have already been removed. + if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } } @@ -77,20 +171,10 @@ func (e *endpoint) MTU() uint32 { return calculateMTU(e.linkEP.MTU()) } -// Capabilities implements stack.NetworkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID -} - // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize + return e.linkEP.MaxHeaderLength() + header.IPv4MaximumHeaderSize } // GSOMaxSize returns the maximum GSO packet size. @@ -106,98 +190,26 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is already present in pkt.NetworkHeader. -// pkt.TransportHeader may be set. mtu includes the IP header and options. This -// does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { - // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.NetworkHeader().View()) - flags := ip.Flags() - - // Update mtu to take into account the header, which will exist in all - // fragments anyway. - innerMTU := mtu - int(ip.HeaderLength()) - - // Round the MTU down to align to 8 bytes. Then calculate the number of - // fragments. Calculate fragment sizes as in RFC791. - innerMTU &^= 7 - n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU - - outerMTU := innerMTU + int(ip.HeaderLength()) - offset := ip.FragmentOffset() - - // Keep the length reserved for link-layer, we need to create fragments with - // the same reserved length. - reservedForLink := pkt.AvailableHeaderBytes() - - // Destroy the packet, pull all payloads out for fragmentation. - transHeader, data := pkt.TransportHeader().View(), pkt.Data - - // Where possible, the first fragment that is sent has the same - // number of bytes reserved for header as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - transFitsFirst := len(transHeader) <= innerMTU - - for i := 0; i < n; i++ { - reserve := reservedForLink + int(ip.HeaderLength()) - if i == 0 && transFitsFirst { - // Reserve for transport header if it's going to be put in the first - // fragment. - reserve += len(transHeader) - } - fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: reserve, - }) - fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - - // Copy data for the fragment. - avail := innerMTU - - if n := len(transHeader); n > 0 { - if n > avail { - n = avail - } - if i == 0 && transFitsFirst { - copy(fragPkt.TransportHeader().Push(n), transHeader) - } else { - fragPkt.Data.AppendView(transHeader[:n:n]) - } - transHeader = transHeader[n:] - avail -= n - } - - if avail > 0 { - n := data.Size() - if n > avail { - n = avail - } - data.ReadToVV(&fragPkt.Data, n) - avail -= n - } - - copied := uint16(innerMTU - avail) - - // Set lengths in header and calculate checksum. - h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) - copy(h, ip) - if i != n-1 { - h.SetTotalLength(uint16(outerMTU)) - h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) - } else { - h.SetTotalLength(uint16(h.HeaderLength()) + copied) - h.SetFlagsFragmentOffset(flags, offset) - } - h.SetChecksum(0) - h.SetChecksum(^h.CalculateChecksum()) - offset += copied +// writePacketFragments fragments pkt and writes the results on the link +// endpoint. The IP header must already present in the original packet. The mtu +// is the maximum size of the packets. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error { + networkHeader := header.IPv4(pkt.NetworkHeader().View()) + fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) + pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) - // Send out the fragment. + for { + fragPkt, more := buildNextFragment(&pf, networkHeader) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) return err } r.Stats().IP.PacketsSent.Increment() + if !more { + break + } } + return nil } @@ -219,7 +231,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.NetworkProtocolNumber = ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. @@ -228,8 +240,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.stack.FindNICNameFromID(e.NICID()) - ipt := e.stack.IPTables() + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesOutputDropped.Increment() @@ -245,7 +257,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) @@ -262,9 +274,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) + return e.writePacketFragments(r, gso, e.linkEP.MTU(), pkt) } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.Stats().IP.PacketsSent.Increment() @@ -285,16 +298,19 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkt = pkt.Next() } - nicName := e.stack.FindNICNameFromID(e.NICID()) + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) @@ -308,7 +324,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -319,6 +335,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -377,23 +394,39 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return nil } + if err := e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } r.Stats().IP.PacketsSent.Increment() - - return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + return nil } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { + return + } + h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + // As per RFC 1122 section 3.2.1.3: + // When a host sends any datagram, the IP source address MUST + // be one of its own IP addresses (but not a broadcast or + // multicast address). + if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { + r.Stats().IP.InvalidSourceAddressesReceived.Increment() + return + } + // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.stack.IPTables() + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesInputDropped.Increment() @@ -449,6 +482,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } } + + r.Stats().IP.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport @@ -458,7 +493,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { e.handleICMP(r, pkt) return } - r.Stats().IP.PacketsDelivered.Increment() switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { case stack.TransportPacketHandled: @@ -469,23 +503,144 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC: 1122 Section 3.2.2.1 + // A host SHOULD generate Destination Unreachable messages with code: + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported + _ = returnError(r, &icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } } // Close cleans up resources associated with the endpoint. -func (e *endpoint) Close() {} +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + e.disableLocked() + e.mu.addressableEndpointState.Cleanup() +} + +// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) +} + +// RemovePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.RemovePermanentAddress(addr) +} + +// MainAddress implements stack.AddressableEndpoint. +func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.MainAddress() +} + +// AcquireAssignedAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + + loopback := e.nic.IsLoopback() + addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { + subnet := addressEndpoint.AddressWithPrefix().Subnet() + // IPv4 has a notion of a subnet broadcast address and considers the + // loopback interface bound to an address's whole subnet (on linux). + return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) + }) + if addressEndpoint != nil { + return addressEndpoint + } + + if !allowTemp { + return nil + } + + addr := localAddr.WithPrefix() + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB) + if err != nil { + // AddAddress only returns an error if the address is already assigned, + // but we just checked above if the address exists so we expect no error. + panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err)) + } + return addressEndpoint +} + +// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PrimaryAddresses() +} + +// PermanentAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PermanentAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV4MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.LeaveGroup(addr) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.IsInGroup(addr) +} + +var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) +var _ stack.NetworkProtocol = (*protocol)(nil) type protocol struct { - ids []uint32 - hashIV uint32 + stack *stack.Stack // defaultTTL is the current default TTL for the protocol. Only the - // uint8 portion of it is meaningful and it must be accessed - // atomically. + // uint8 portion of it is meaningful. + // + // Must be accessed using atomic operations. defaultTTL uint32 + // forwarding is set to 1 when the protocol has forwarding enabled and 0 + // when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + + ids []uint32 + hashIV uint32 + fragmentation *fragmentation.Fragmentation } @@ -558,6 +713,20 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) Forwarding() bool { + return uint8(atomic.LoadUint32(&p.forwarding)) == 1 +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) SetForwarding(v bool) { + if v { + atomic.StoreUint32(&p.forwarding, 1) + } else { + atomic.StoreUint32(&p.forwarding, 0) + } +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { @@ -567,19 +736,41 @@ func calculateMTU(mtu uint32) uint32 { return mtu - header.IPv4MinimumSize } +// calculateFragmentInnerMTU calculates the maximum number of bytes of +// fragmentable data a fragment can have, based on the link layer mtu and pkt's +// network header size. +func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { + if mtu > MaxTotalSize { + mtu = MaxTotalSize + } + mtu -= uint32(pkt.NetworkHeader().View().Size()) + // Round the MTU down to align to 8 bytes. + mtu &^= 7 + return mtu +} + +// addressToUint32 translates an IPv4 address into its little endian uint32 +// representation. +// +// This function does the same thing as binary.LittleEndian.Uint32 but operates +// on a tcpip.Address (a string) without the need to convert it to a byte slice, +// which would cause an allocation. +func addressToUint32(addr tcpip.Address) uint32 { + _ = addr[3] // bounds check hint to compiler + return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 +} + // hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number, and a random initial -// value (generated once on initialization) to generate the hash. +// destination address, the transport protocol number and a 32-bit number to +// generate the hash. func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - t := r.LocalAddress - a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 - t = r.RemoteAddress - b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + a := addressToUint32(r.LocalAddress) + b := addressToUint32(r.RemoteAddress) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } // NewProtocol returns an IPv4 network protocol. -func NewProtocol(*stack.Stack) stack.NetworkProtocol { +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -590,9 +781,33 @@ func NewProtocol(*stack.Stack) stack.NetworkProtocol { hashIV := r[buckets] return &protocol{ + stack: s, ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), } } + +func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { + fragPkt, offset, copied, more := pf.BuildNextFragment() + fragPkt.NetworkProtocolNumber = ProtocolNumber + + originalIPHeaderLength := len(originalIPHeader) + nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + + if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { + panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) + } + + flags := originalIPHeader.Flags() + if more { + flags |= header.IPv4FlagMoreFragments + } + nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset)) + nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied)) + nextFragIPHeader.SetChecksum(0) + nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum()) + + return fragPkt, more +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 0b3ed9483..f250a3cde 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -18,17 +18,20 @@ import ( "bytes" "encoding/hex" "math" + "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -92,6 +95,276 @@ func TestExcludeBroadcast(t *testing.T) { }) } +// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and +// checks the response. +func TestIPv4Sanity(t *testing.T) { + const ( + defaultMTU = header.IPv6MinimumMTU + ttl = 255 + nicID = 1 + randomSequence = 123 + randomIdent = 42 + ) + var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + ) + + tests := []struct { + name string + headerLength uint8 // value of 0 means "use correct size" + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + shouldFail bool + expectICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + options []byte + }{ + { + name: "valid", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + }, + // The TTL tests check that we are not rejecting an incoming packet + // with a zero or one TTL, which has been a point of confusion in the + // past as RFC 791 says: "If this field contains the value zero, then the + // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies + // for the case of the destination host, stating as follows. + // + // A host MUST NOT send a datagram with a Time-to-Live (TTL) + // value of zero. + // + // A host MUST NOT discard a datagram just because it was + // received with TTL less than 2. + { + name: "zero TTL", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 0, + shouldFail: false, + }, + { + name: "one TTL", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + shouldFail: false, + }, + { + name: "End options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{0, 0, 0, 0}, + }, + { + name: "NOP options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 1, 1}, + }, + { + name: "NOP and End options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 0, 0}, + }, + { + name: "bad header length", + headerLength: header.IPv4MinimumSize - 1, + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (0)", + maxTotalLength: 0, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip + icmp - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad protocol", + maxTotalLength: defaultMTU, + transportProtocol: 99, + TTL: ttl, + shouldFail: true, + expectICMP: true, + ICMPType: header.ICMPv4DstUnreachable, + ICMPCode: header.ICMPv4ProtoUnreachable, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + } + + // Default routes for IPv4 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // Round up the header size to the next multiple of 4 as RFC 791, page 11 + // says: "Internet Header Length is the length of the internet header + // in 32 bit words..." and on page 23: "The internet header padding is + // used to ensure that the internet header ends on a 32 bit boundary." + ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1) + + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + + // Specify ident/seq to make sure we get the same in the response. + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + if test.maxTotalLength < totalLen { + totalLen = test.maxTotalLength + } + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHeaderLength), + TotalLength: totalLen, + Protocol: test.transportProtocol, + TTL: test.TTL, + SrcAddr: remoteIPv4Addr, + DstAddr: ipv4Addr.Address, + }) + if n := copy(ip.Options(), test.options); n != len(test.options) { + t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options)) + } + // Override the correct value if the test case specified one. + if test.headerLength != 0 { + ip.SetHeaderLength(test.headerLength) + } + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + reply, ok := e.Read() + if !ok { + if test.shouldFail { + if test.expectICMP { + t.Fatal("expected ICMP error response missing") + } + return // Expected silent failure. + } + t.Fatal("expected ICMP echo reply missing") + } + + // Check the route that brought the packet to us. + if reply.Route.LocalAddress != ipv4Addr.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) + } + if reply.Route.RemoteAddress != remoteIPv4Addr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) + } + + // Make sure it's all in one buffer. + vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) + replyIPHeader := header.IPv4(vv.ToView()) + + // At this stage we only know it's an IP header so verify that much. + checker.IPv4(t, replyIPHeader, + checker.SrcAddr(ipv4Addr.Address), + checker.DstAddr(remoteIPv4Addr), + ) + + // All expected responses are ICMP packets. + if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { + t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + } + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) + + // Sanity check the response. + switch replyICMPHeader.Type() { + case header.ICMPv4DstUnreachable: + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + if !test.shouldFail || !test.expectICMP { + t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + return + case header.ICMPv4EchoReply: + checker.IPv4(t, replyIPHeader, + checker.IPv4HeaderLength(ipHeaderLength), + checker.IPv4Options(test.options), + checker.IPFullLength(uint16(requestPkt.Size())), + checker.ICMPv4( + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Seq(randomSequence), + checker.ICMPv4Ident(randomIdent), + checker.ICMPv4Checksum(), + ), + ) + if test.shouldFail { + t.Fatalf("unexpected Echo Reply packet\n") + } + default: + t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + } + }) + } +} + // comparePayloads compared the contents of all the packets against the contents // of the source packet. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { @@ -123,16 +396,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if i == 0 { - got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() - // sourcePacketInfo does not have NetworkHeader added, simulate one. - want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() - // Check that it kept the transport header in packet.TransportHeader if - // it fits in the first fragment. - if want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) - } - } if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } @@ -162,6 +425,8 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } func TestFragmentation(t *testing.T) { + const ttl = 42 + var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { manyPayloadViewsSizes[i] = 7 @@ -175,15 +440,15 @@ func TestFragmentation(t *testing.T) { payloadViewsSizes []int expectedFrags int }{ - {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, + {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, + {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, + {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, + {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, + {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, + {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, + {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, + {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, } for _, ft := range fragTests { @@ -194,11 +459,11 @@ func TestFragmentation(t *testing.T) { source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, - TTL: 42, + TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Errorf("got err = %s, want = nil", err) + t.Fatalf("r.WritePacket(_, _, _) = %s", err) } if got := len(ep.WrittenPackets); got != ft.expectedFrags { @@ -207,6 +472,9 @@ func TestFragmentation(t *testing.T) { if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } compareFragments(t, ep.WrittenPackets, source, ft.mtu) }) } @@ -215,36 +483,70 @@ func TestFragmentation(t *testing.T) { // TestFragmentationErrors checks that errors are returned from write packet // correctly. func TestFragmentationErrors(t *testing.T) { + const ttl = 42 + + expectedError := tcpip.ErrAborted fragTests := []struct { description string mtu uint32 transportHeaderLength int - payloadViewsSizes []int - err *tcpip.Error + payloadSize int allowPackets int + fragmentCount int }{ - {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1}, - {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0}, + { + description: "No frag", + mtu: 2000, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 0, + fragmentCount: 1, + }, + { + description: "Error on first frag", + mtu: 500, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 0, + fragmentCount: 3, + }, + { + description: "Error on second frag", + mtu: 500, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 1, + fragmentCount: 3, + }, + { + description: "Error on first frag MTU smaller than header", + mtu: 500, + transportHeaderLength: 1000, + payloadSize: 500, + allowPackets: 0, + fragmentCount: 4, + }, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets) + ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, - TTL: 42, + TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.err { - t.Errorf("got WritePacket() = %s, want = %s", err, ft.err) + if err != expectedError { + t.Errorf("got WritePacket() = %s, want = %s", err, expectedError) } if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) } + if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) + } }) } } @@ -1005,6 +1307,7 @@ func TestReceiveFragments(t *testing.T) { func TestWriteStats(t *testing.T) { const nPackets = 3 + tests := []struct { name string setup func(*testing.T, *stack.Stack) @@ -1040,7 +1343,7 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to find filter table") } ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } @@ -1062,10 +1365,10 @@ func TestWriteStats(t *testing.T) { } // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } @@ -1150,12 +1453,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { dst = "\x10\x00\x00\x02" ) if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err) + t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) } { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + mask := tcpip.AddressMask(header.IPv4Broadcast) + subnet, err := tcpip.NewSubnet(dst, mask) if err != nil { - t.Fatalf("NewSubnet(_, _) failed: %v", err) + t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, @@ -1164,7 +1468,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { } rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err) } return rt } diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 8bd8f5c52..a30437f02 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -5,16 +5,20 @@ package(licenses = ["notice"]) go_library( name = "ipv6", srcs = [ + "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "ndp.go", ], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", + "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", ], ) @@ -38,6 +42,7 @@ go_test( "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go index d199ded6a..09ba133b1 100644 --- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go +++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go @@ -14,7 +14,7 @@ // Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. -package stack +package ipv6 import "strconv" diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index dd3295b31..629d1818e 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,6 +15,8 @@ package ipv6 import ( + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -39,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match an address we own. src := hdr.SourceAddress() - if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { return } @@ -207,14 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - s := r.Stack() - if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { - // We will only get an error if the NIC is unrecognized, which should not - // happen. For now, drop this packet. - // - // TODO(b/141002840): Handle this better? - return - } else if isTentative { + if e.hasTentativeAddr(targetAddr) { // If the target address is tentative and the source of the packet is a // unicast (specified) address, then the source of the packet is // attempting to perform address resolution on the target. In this case, @@ -227,7 +222,20 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // stack know so it can handle such a scenario and do nothing further with // the NS. if r.RemoteAddress == header.IPv6Any { - s.DupTentativeAddrDetected(e.nicID, targetAddr) + // We would get an error if the address no longer exists or the address + // is no longer tentative (DAD resolved between the call to + // hasTentativeAddr and this point). Both of these are valid scenarios: + // 1) An address may be removed at any time. + // 2) As per RFC 4862 section 5.4, DAD is not a perfect: + // "Note that the method for detecting duplicates + // is not completely reliable, and it is possible that duplicate + // addresses will still exist" + // + // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate + // address is detected for an assigned address. + if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) + } } // Do not handle neighbor solicitations targeted to an address that is @@ -240,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { + if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } @@ -275,7 +283,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } else if e.nud != nil { e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) + } + + // As per RFC 4861 section 7.1.1: + // A node MUST silently discard any received Neighbor Solicitation + // messages that do not satisfy all of the following validity checks: + // ... + // - If the IP source address is the unspecified address, the IP + // destination address is a solicited-node multicast address. + if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + received.Invalid.Increment() + return } // ICMPv6 Neighbor Solicit messages are always sent to @@ -353,20 +372,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // NDP datagrams are very small and ToView() will not incur allocations. na := header.NDPNeighborAdvert(payload.ToView()) targetAddr := na.TargetAddress() - s := r.Stack() - - if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { - // We will only get an error if the NIC is unrecognized, which should not - // happen. For now short-circuit this packet. - // - // TODO(b/141002840): Handle this better? - return - } else if isTentative { + if e.hasTentativeAddr(targetAddr) { // We just got an NA from a node that owns an address we are performing // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with // the NDP NA. - s.DupTentativeAddrDetected(e.nicID, targetAddr) + // + // We would get an error if the address no longer exists or the address + // is no longer tentative (DAD resolved between the call to + // hasTentativeAddr and this point). Both of these are valid scenarios: + // 1) An address may be removed at any time. + // 2) As per RFC 4862 section 5.4, DAD is not a perfect: + // "Note that the method for detecting duplicates + // is not completely reliable, and it is possible that duplicate + // addresses will still exist" + // + // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate + // address is detected for an assigned address. + if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) + } return } @@ -396,7 +421,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // address cache with the link address for the target of the message. if len(targetLinkAddr) != 0 { if e.nud == nil { - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) return } @@ -424,7 +449,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme localAddr = "" } - r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -568,9 +593,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } - // Tell the NIC to handle the RA. - stack := r.Stack() - stack.HandleNDPRA(e.nicID, routerAddr, ra) + e.mu.Lock() + e.mu.ndp.handleRA(routerAddr, ra) + e.mu.Unlock() case header.ICMPv6RedirectMsg: // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating @@ -676,6 +701,36 @@ type icmpReason interface { isICMPReason() } +// icmpReasonParameterProblem is an error during processing of extension headers +// or the fixed header defined in RFC 4443 section 3.4. +type icmpReasonParameterProblem struct { + code header.ICMPv6Code + + // respondToMulticast indicates that we are sending a packet that falls under + // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondToMulticast bool + + // pointer is defined in the RFC 4443 setion 3.4 which reads: + // + // Pointer Identifies the octet offset within the invoking packet + // where the error was detected. + // + // The pointer will point beyond the end of the ICMPv6 + // packet if the field in error is beyond what can fit + // in the maximum size of an ICMPv6 error message. + pointer uint32 +} + +func (*icmpReasonParameterProblem) isICMPReason() {} + // icmpReasonPortUnreachable is an error where the transport protocol has no // listener and no alternative means to inform the sender. type icmpReasonPortUnreachable struct{} @@ -695,7 +750,7 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // - // TODO(b/164522993) There are exceptions to this rule. + // There are exceptions to this rule. // See: point e.3) RFC 4443 section-2.4 // // (e) An ICMPv6 error message MUST NOT be originated as a result of @@ -713,7 +768,12 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // Section 4.2 of [IPv6]) that has the Option Type highest- // order two bits set to 10). // - if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any { + var allowResponseToMulticast bool + if reason, ok := reason.(*icmpReasonParameterProblem); ok { + allowResponseToMulticast = reason.respondToMulticast + } + + if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { return nil } @@ -766,10 +826,21 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - icmpHdr.SetType(header.ICMPv6DstUnreachable) + var counter *tcpip.StatCounter + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + icmpHdr.SetType(header.ICMPv6ParamProblem) + icmpHdr.SetCode(reason.code) + icmpHdr.SetTypeSpecific(reason.pointer) + counter = sent.ParamProblem + case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + counter = sent.DstUnreachable + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data)) - counter := sent.DstUnreachable err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt) if err != nil { sent.Dropped.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index dd58022d6..3a40b790e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -50,6 +50,10 @@ type stubLinkEndpoint struct { stack.LinkEndpoint } +func (*stubLinkEndpoint) MTU() uint32 { + return defaultMTU +} + func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { // Indicate that resolution for link layer addresses is required to send // packets over this link. This is needed so the NIC knows to allocate a @@ -103,6 +107,30 @@ func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Lin func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -150,9 +178,13 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) @@ -288,9 +320,13 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}) defer ep.Close() + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) @@ -1204,7 +1240,10 @@ func TestLinkAddressRequest(t *testing.T) { } for _, test := range tests { - p := NewProtocol(nil) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + p := s.NetworkProtocolInstance(ProtocolNumber) linkRes, ok := p.(stack.LinkAddressResolver) if !ok { t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e436c6a9e..73e50f8d6 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,14 +16,19 @@ package ipv6 import ( + "encoding/binary" "fmt" + "hash/fnv" + "sort" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -38,16 +43,307 @@ const ( // DefaultTTL is the default hop limit for IPv6 Packets egressed by // Netstack. DefaultTTL = 64 + + // buckets for fragment identifiers + buckets = 2048 ) +var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) +var _ stack.AddressableEndpoint = (*endpoint)(nil) +var _ stack.NetworkEndpoint = (*endpoint)(nil) +var _ stack.NDPEndpoint = (*endpoint)(nil) +var _ NDPEndpoint = (*endpoint)(nil) + type endpoint struct { - nicID tcpip.NICID + nic stack.NetworkInterface linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler dispatcher stack.TransportDispatcher protocol *protocol stack *stack.Stack + + // enabled is set to 1 when the endpoint is enabled and 0 when it is + // disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + + mu struct { + sync.RWMutex + + addressableEndpointState stack.AddressableEndpointState + ndp ndpState + } +} + +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it is passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte +} + +// InvalidateDefaultRouter implements stack.NDPEndpoint. +func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.ndp.invalidateDefaultRouter(rtr) +} + +// SetNDPConfigurations implements NDPEndpoint. +func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) { + c.validate() + e.mu.Lock() + defer e.mu.Unlock() + e.mu.ndp.configs = c +} + +// hasTentativeAddr returns true if addr is tentative on e. +func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { + e.mu.RLock() + addressEndpoint := e.getAddressRLocked(addr) + e.mu.RUnlock() + return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative +} + +// dupTentativeAddrDetected attempts to inform e that a tentative addr is a +// duplicate on a link. +// +// dupTentativeAddrDetected removes the tentative address if it exists. If the +// address was generated via SLAAC, an attempt is made to generate a new +// address. +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil { + return tcpip.ErrBadAddress + } + + if addressEndpoint.GetKind() != stack.PermanentTentative { + return tcpip.ErrInvalidEndpointState + } + + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an + // attempt will be made to generate a new address for it. + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + return err + } + + prefix := addressEndpoint.AddressWithPrefix().Subnet() + + switch t := addressEndpoint.ConfigType(); t { + case stack.AddressConfigStatic: + case stack.AddressConfigSlaac: + e.mu.ndp.regenerateSLAACAddr(prefix) + case stack.AddressConfigSlaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) + default: + panic(fmt.Sprintf("unrecognized address config type = %d", t)) + } + + return nil +} + +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.Enabled() { + return + } + + if forwarding { + // When transitioning into an IPv6 router, host-only state (NDP discovered + // routers, discovered on-link prefixes, and auto-generated addresses) is + // cleaned up/invalidated and NDP router solicitations are stopped. + e.mu.ndp.stopSolicitingRouters() + e.mu.ndp.cleanupState(true /* hostOnly */) + } else { + // When transitioning into an IPv6 host, NDP router solicitations are + // started. + e.mu.ndp.startSolicitingRouters() + } +} + +// Enable implements stack.NetworkEndpoint. +func (e *endpoint) Enable() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If the NIC is not enabled, the endpoint can't do anything meaningful so + // don't enable the endpoint. + if !e.nic.Enabled() { + return tcpip.ErrNotPermitted + } + + // If the endpoint is already enabled, there is nothing for it to do. + if !e.setEnabled(true) { + return nil + } + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + // + // Also auto-generate an IPv6 link-local address based on the endpoint's + // link address if it is configured to do so. Note, each interface is + // required to have IPv6 link-local unicast address, as per RFC 4291 + // section 2.1. + + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent + // state. + // + // Addresses may have aleady completed DAD but in the time since the endpoint + // was last enabled, other devices may have acquired the same addresses. + var err *tcpip.Error + e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + addr := addressEndpoint.AddressWithPrefix().Address + if !header.IsV6UnicastAddress(addr) { + return true + } + + switch addressEndpoint.GetKind() { + case stack.Permanent: + addressEndpoint.SetKind(stack.PermanentTentative) + fallthrough + case stack.PermanentTentative: + err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint) + return err == nil + default: + return true + } + }) + if err != nil { + return err + } + + // Do not auto-generate an IPv6 link-local address for loopback devices. + if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() { + // The valid and preferred lifetime is infinite for the auto-generated + // link-local address. + e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) + } + + // If we are operating as a router, then do not solicit routers since we + // won't process the RAs anyway. + // + // Routers do not process Router Advertisements (RA) the same way a host + // does. That is, routers do not learn from RAs (e.g. on-link prefixes + // and default routers). Therefore, soliciting RAs from other routers on + // a link is unnecessary for routers. + if !e.protocol.Forwarding() { + e.mu.ndp.startSolicitingRouters() + } + + return nil +} + +// Enabled implements stack.NetworkEndpoint. +func (e *endpoint) Enabled() bool { + return e.nic.Enabled() && e.isEnabled() +} + +// isEnabled returns true if the endpoint is enabled, regardless of the +// enabled status of the NIC. +func (e *endpoint) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) == 1 +} + +// setEnabled sets the enabled status for the endpoint. +// +// Returns true if the enabled status was updated. +func (e *endpoint) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&e.enabled, 1) == 0 + } + return atomic.SwapUint32(&e.enabled, 0) == 1 +} + +// Disable implements stack.NetworkEndpoint. +func (e *endpoint) Disable() { + e.mu.Lock() + defer e.mu.Unlock() + e.disableLocked() +} + +func (e *endpoint) disableLocked() { + if !e.setEnabled(false) { + return + } + + e.mu.ndp.stopSolicitingRouters() + e.mu.ndp.cleanupState(false /* hostOnly */) + e.stopDADForPermanentAddressesLocked() + + // The endpoint may have already left the multicast group. + if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) + } +} + +// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) stopDADForPermanentAddressesLocked() { + // Stop DAD for all the tentative unicast addresses. + e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.GetKind() != stack.PermanentTentative { + return true + } + + addr := addressEndpoint.AddressWithPrefix().Address + if header.IsV6UnicastAddress(addr) { + e.mu.ndp.stopDuplicateAddressDetection(addr) + } + + return true + }) } // DefaultTTL is the default hop limit for this endpoint. @@ -61,16 +357,6 @@ func (e *endpoint) MTU() uint32 { return calculateMTU(e.linkEP.MTU()) } -// NICID returns the ID of the NIC this endpoint belongs to. -func (e *endpoint) NICID() tcpip.NICID { - return e.nicID -} - -// Capabilities implements stack.NetworkEndpoint.Capabilities. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { @@ -96,7 +382,44 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + pkt.NetworkProtocolNumber = ProtocolNumber +} + +func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { + return pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) +} + +// handleFragments fragments pkt and calls the handler function on each +// fragment. It returns the number of fragments handled and the number of +// fragments left to be processed. The IP header must already be present in the +// original packet. The mtu is the maximum size of the packets. The transport +// header protocol number is required to avoid parsing the IPv6 extension +// headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) + if fragMTU < pkt.TransportHeader().View().Size() { + // As per RFC 8200 Section 4.5, the Transport Header is expected to be small + // enough to fit in the first fragment. + return 0, 1, tcpip.ErrMessageTooLong + } + + pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt)) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + var n int + for { + fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id) + if err := handler(fragPkt); err != nil { + return n, pf.RemainingFragmentCount() + 1, err + } + n++ + if !more { + break + } + } + + return n, 0, nil } // WritePacket writes a packet to the given destination address and protocol. @@ -105,8 +428,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. - nicName := e.stack.FindNICNameFromID(e.NICID()) - ipt := e.stack.IPTables() + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesOutputDropped.Increment() @@ -122,7 +445,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) return nil @@ -143,14 +466,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } + if e.packetMustBeFragmented(pkt, gso) { + sent, remain, err := e.handleFragments(r, gso, e.linkEP.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each + // fragment one by one using WritePacket() (current strategy) or if we + // want to create a PacketBufferList from the fragments and feed it to + // WritePackets(). It'll be faster but cost more memory. + return e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt) + }) + r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + return err + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() return err } + r.Stats().IP.PacketsSent.Increment() return nil } -// WritePackets implements stack.LinkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") @@ -161,18 +499,38 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) + if e.packetMustBeFragmented(pb, gso) { + current := pb + _, _, err := e.handleFragments(r, gso, e.linkEP.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Modify the packet list in place with the new fragments. + pkts.InsertAfter(current, fragPkt) + current = current.Next() + return nil + }) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + // The fragmented packet can be released. The rest of the packets can be + // processed. + pkts.Remove(pb) + pb = current + } } // iptables filtering. All packets that reach here are locally // generated. - nicName := e.stack.FindNICNameFromID(e.NICID()) - ipt := e.stack.IPTables() + nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + ipt := e.protocol.stack.IPTables() dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) @@ -186,7 +544,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -197,6 +555,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -219,12 +578,24 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { + if !e.isEnabled() { + return + } + h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + // As per RFC 4291 section 2.7: + // Multicast addresses must not be used as source addresses in IPv6 + // packets or appear in any Routing header. + if header.IsV6MulticastAddress(r.RemoteAddress) { + r.Stats().IP.InvalidSourceAddressesReceived.Increment() + return + } + // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. @@ -236,15 +607,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { hasFragmentHeader := false // iptables filtering. All packets that reach here are intended for - // this machine and will not be forwarded. - ipt := e.stack.IPTables() + // this machine and need not be forwarded. + ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesInputDropped.Increment() return } - for firstHeader := true; ; firstHeader = false { + for { + // Keep track of the start of the previous header so we can report the + // special case of a Hop by Hop at a location other than at the start. + previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -258,11 +632,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 - // (unrecognized next header) error in response to an extension header's - // Next Header field with the Hop By Hop extension header identifier. - if !firstHeader { + if previousHeaderStart != 0 { + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + }, pkt) return } @@ -284,13 +658,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) @@ -301,16 +687,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.4, if a node encounters a routing header with // an unrecognized routing type value, with a non-zero Segments Left // value, the node must discard the packet and send an ICMP Parameter - // Problem, Code 0. If the Segments Left is 0, the node must ignore the - // Routing extension header and process the next header in the packet. + // Problem, Code 0 to the packet's Source Address, pointing to the + // unrecognized Routing Type. + // + // If the Segments Left is 0, the node must ignore the Routing extension + // header and process the next header in the packet. // // Note, the stack does not yet handle any type of routing extension // header, so we just make sure Segments Left is zero before processing // the next extension header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for - // unrecognized routing types with a non-zero Segments Left value. if extHdr.SegmentsLeft() != 0 { + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: it.ParseOffset(), + }, pkt) return } @@ -445,13 +835,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) @@ -470,13 +872,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf + r.Stats().IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: @@ -486,17 +887,40 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC 8200 section 4. (page 7): + // Extension headers are numbered from IANA IP Protocol Numbers + // [IANA-PN], the same values used for IPv4 and IPv6. When + // processing a sequence of Next Header values in a packet, the + // first one that is not an extension header [IANA-EH] indicates + // that the next item in the packet is the corresponding upper-layer + // header. + // With more related information on page 8: + // If, as a result of processing a header, the destination node is + // required to proceed to the next header but the Next Header value + // in the current header is unrecognized by the node, it should + // discard the packet and send an ICMP Parameter Problem message to + // the source of the packet, with an ICMP Code value of 1 + // ("unrecognized Next Header type encountered") and the ICMP + // Pointer field containing the offset of the unrecognized value + // within the original packet. + // + // Which when taken together indicate that an unknown protocol should + // be treated as an unrecognized next header value. + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } } default: - // If we receive a packet for an extension header we do not yet handle, - // drop the packet for now. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. + _ = returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) r.Stats().UnknownProtocolRcvdPackets.Increment() return } @@ -504,19 +928,343 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } // Close cleans up resources associated with the endpoint. -func (*endpoint) Close() {} +func (e *endpoint) Close() { + e.mu.Lock() + e.disableLocked() + e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */) + e.stopDADForPermanentAddressesLocked() + e.mu.addressableEndpointState.Cleanup() + e.mu.Unlock() + + e.protocol.forgetEndpoint(e) +} // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } +// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + // TODO(b/169350103): add checks here after making sure we no longer receive + // an empty address. + e.mu.Lock() + defer e.mu.Unlock() + return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) +} + +// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but +// with locking requirements. +// +// addAndAcquirePermanentAddressLocked also joins the passed address's +// solicited-node multicast group and start duplicate address detection. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { + addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err != nil { + return nil, err + } + + if !header.IsV6UnicastAddress(addr.Address) { + return addressEndpoint, nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { + return nil, err + } + + addressEndpoint.SetKind(stack.PermanentTentative) + + if e.Enabled() { + if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil { + return nil, err + } + } + + return addressEndpoint, nil +} + +// RemovePermanentAddress implements stack.AddressableEndpoint. +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { + return tcpip.ErrBadLocalAddress + } + + return e.removePermanentEndpointLocked(addressEndpoint, true) +} + +// removePermanentEndpointLocked is like removePermanentAddressLocked except +// it works with a stack.AddressEndpoint. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error { + addr := addressEndpoint.AddressWithPrefix() + unicast := header.IsV6UnicastAddress(addr.Address) + if unicast { + e.mu.ndp.stopDuplicateAddressDetection(addr.Address) + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + switch addressEndpoint.ConfigType() { + case stack.AddressConfigSlaac: + e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + case stack.AddressConfigSlaacTemp: + e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + } + } + + if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil { + return err + } + + if !unicast { + return nil + } + + snmc := header.SolicitedNodeAddr(addr.Address) + if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + + return nil +} + +// hasPermanentAddressLocked returns true if the endpoint has a permanent +// address equal to the passed address. +// +// Precondition: e.mu must be read or write locked. +func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool { + addressEndpoint := e.getAddressRLocked(addr) + if addressEndpoint == nil { + return false + } + return addressEndpoint.GetKind().IsPermanent() +} + +// getAddressRLocked returns the endpoint for the passed address. +// +// Precondition: e.mu must be read or write locked. +func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint { + return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr) +} + +// MainAddress implements stack.AddressableEndpoint. +func (e *endpoint) MainAddress() tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.MainAddress() +} + +// AcquireAssignedAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + e.mu.Lock() + defer e.mu.Unlock() + return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB) +} + +// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with +// locking requirements. +// +// Precondition: e.mu must be write locked. +func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint { + return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB) +} + +// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. +func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + e.mu.RLock() + defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { + // addrCandidate is a candidate for Source Address Selection, as per + // RFC 6724 section 5. + type addrCandidate struct { + addressEndpoint stack.AddressEndpoint + scope header.IPv6AddressScope + } + + if len(remoteAddr) == 0 { + return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + var cs []addrCandidate + e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !addressEndpoint.IsAssigned(allowExpired) { + return + } + + addr := addressEndpoint.AddressWithPrefix().Address + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) + } + + cs = append(cs, addrCandidate{ + addressEndpoint: addressEndpoint, + scope: scope, + }) + }) + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + return true + } + if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. + if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp { + return saTemp + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if c.addressEndpoint.IncRef() { + return c.addressEndpoint + } + } + + return nil +} + +// PrimaryAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PrimaryAddresses() +} + +// PermanentAddresses implements stack.AddressableEndpoint. +func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.PermanentAddresses() +} + +// JoinGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + if !header.IsV6MulticastAddress(addr) { + return false, tcpip.ErrBadAddress + } + + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.JoinGroup(addr) +} + +// LeaveGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.addressableEndpointState.LeaveGroup(addr) +} + +// IsInGroup implements stack.GroupAddressableEndpoint. +func (e *endpoint) IsInGroup(addr tcpip.Address) bool { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.addressableEndpointState.IsInGroup(addr) +} + +var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) +var _ stack.NetworkProtocol = (*protocol)(nil) + type protocol struct { + stack *stack.Stack + + mu struct { + sync.RWMutex + + eps map[*endpoint]struct{} + } + + ids []uint32 + hashIV uint32 + // defaultTTL is the current default TTL for the protocol. Only the - // uint8 portion of it is meaningful and it must be accessed - // atomically. - defaultTTL uint32 + // uint8 portion of it is meaningful. + // + // Must be accessed using atomic operations. + defaultTTL uint32 + + // forwarding is set to 1 when the protocol has forwarding enabled and 0 + // when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + fragmentation *fragmentation.Fragmentation + + // ndpDisp is the NDP event dispatcher that is used to send the netstack + // integrator NDP related events. + ndpDisp NDPDispatcher + + // ndpConfigs is the default NDP configurations used by an IPv6 endpoint. + ndpConfigs NDPConfigurations + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // tempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + tempIIDSeed []byte + + // autoGenIPv6LinkLocal determines whether or not the stack attempts to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. + autoGenIPv6LinkLocal bool } // Number returns the ipv6 protocol number. @@ -541,16 +1289,36 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { - return &endpoint{ - nicID: nicID, - linkEP: linkEP, +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &endpoint{ + nic: nic, + linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, dispatcher: dispatcher, protocol: p, - stack: st, } + e.mu.addressableEndpointState.Init(e) + e.mu.ndp = ndpState{ + ep: e, + configs: p.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), + } + e.mu.ndp.initializeTempAddrState() + + p.mu.Lock() + defer p.mu.Unlock() + p.mu.eps[e] = struct{}{} + return e +} + +func (p *protocol) forgetEndpoint(e *endpoint) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, e) } // SetOption implements NetworkProtocol.SetOption. @@ -601,6 +1369,35 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) Forwarding() bool { + return uint8(atomic.LoadUint32(&p.forwarding)) == 1 +} + +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.SwapUint32(&p.forwarding, 1) == 0 + } + return atomic.SwapUint32(&p.forwarding, 0) == 1 +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (p *protocol) SetForwarding(v bool) { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.setForwarding(v) { + return + } + + for ep := range p.mu.eps { + ep.transitionForwarding(v) + } +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { @@ -611,10 +1408,144 @@ func calculateMTU(mtu uint32) uint32 { return maxPayloadSize } -// NewProtocol returns an IPv6 network protocol. -func NewProtocol(*stack.Stack) stack.NetworkProtocol { - return &protocol{ - defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), +// Options holds options to configure a new protocol. +type Options struct { + // NDPConfigs is the default NDP configurations used by interfaces. + NDPConfigs NDPConfigurations + + // AutoGenIPv6LinkLocal determines whether or not the stack attempts to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // + // Note, setting this to true does not mean that a link-local address is + // assigned right away, or at all. If Duplicate Address Detection is enabled, + // an address is only assigned if it successfully resolves. If it fails, no + // further attempts are made to auto-generate an IPv6 link-local adddress. + // + // The generated link-local address follows RFC 4291 Appendix A guidelines. + AutoGenIPv6LinkLocal bool + + // NDPDisp is the NDP event dispatcher that an integrator can provide to + // receive NDP related events. + NDPDisp NDPDispatcher + + // OpaqueIIDOpts hold the options for generating opaque interface + // identifiers (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // TempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + // + // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // and random from the perspective of other nodes on the network. It is + // recommended that the seed be a random byte buffer of at least + // header.IIDSize bytes to make sure that temporary SLAAC addresses are + // sufficiently random. It should follow minimum randomness requirements for + // security as outlined by RFC 4086. + // + // Note: using a nil value, the same seed across netstack program runs, or a + // seed that is too small would reduce randomness and increase predictability, + // defeating the purpose of temporary SLAAC addresses. + TempIIDSeed []byte +} + +// NewProtocolWithOptions returns an IPv6 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { + opts.NDPConfigs.validate() + + ids := hash.RandN32(buckets) + hashIV := hash.RandN32(1)[0] + + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), + ids: ids, + hashIV: hashIV, + + ndpDisp: opts.NDPDisp, + ndpConfigs: opts.NDPConfigs, + opaqueIIDOpts: opts.OpaqueIIDOpts, + tempIIDSeed: opts.TempIIDSeed, + autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + } + p.mu.eps = make(map[*endpoint]struct{}) + p.SetDefaultTTL(DefaultTTL) + return p + } +} + +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) +} + +// calculateFragmentInnerMTU calculates the maximum number of bytes of +// fragmentable data a fragment can have, based on the link layer mtu and pkt's +// network header size. +func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // MTU because they should only be transmitted once. + mtu -= uint32(pkt.NetworkHeader().View().Size()) + mtu -= header.IPv6FragmentHeaderSize + // Round the MTU down to align to 8 bytes. + mtu &^= 7 + if mtu <= maxPayloadSize { + return mtu + } + return maxPayloadSize +} + +func calculateFragmentReserve(pkt *stack.PacketBuffer) int { + return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize +} + +// hashRoute calculates a hash value for the given route. It uses the source & +// destination address and 32-bit number to generate the hash. +func hashRoute(r *stack.Route, hashIV uint32) uint32 { + // The FNV-1a was chosen because it is a fast hashing algorithm, and + // cryptographic properties are not needed here. + h := fnv.New32a() + if _, err := h.Write([]byte(r.LocalAddress)); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err)) } + + if _, err := h.Write([]byte(r.RemoteAddress)); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err)) + } + + s := make([]byte, 4) + binary.LittleEndian.PutUint32(s, hashIV) + if _, err := h.Write(s); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err)) + } + + return h.Sum32() +} + +func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) { + fragPkt, offset, copied, more := pf.BuildNextFragment() + fragPkt.NetworkProtocolNumber = ProtocolNumber + + originalIPHeadersLength := len(originalIPHeaders) + fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) + + // Copy the IPv6 header and any extension headers already populated. + if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { + panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) + } + fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) + + fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) + fragmentHeader.Encode(&header.IPv6FragmentFields{ + M: more, + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + Identification: id, + NextHeader: uint8(transportProto), + }) + + return fragPkt, more } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 8ae146c5e..e792ca9e2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,17 +15,21 @@ package ipv6 import ( + "encoding/hex" + "fmt" "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -136,6 +140,82 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst } } +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { + // sourcePacket does not have its IP Header populated. Let's copy the one + // from the first fragment. + source := header.IPv6(packets[0].NetworkHeader().View()) + sourceIPHeadersLen := len(source) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) + source = append(source, vv.ToView()...) + + var reassembledPayload buffer.VectorisedView + for i, fragment := range packets { + // Confirm that the packet is valid. + allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views()) + fragmentIPHeaders := header.IPv6(allBytes.ToView()) + if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders)) + } + + fragmentIPHeadersLength := fragment.NetworkHeader().View().Size() + if fragmentIPHeadersLength != sourceIPHeadersLen { + return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen) + } + + if got := len(fragmentIPHeaders); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu) + } + + sourceIPHeader := source[:header.IPv6MinimumSize] + fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize] + + if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize { + return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize) + } + + // We expect the IPv6 Header to be similar across each fragment, besides the + // payload length. + sourceIPHeader.SetPayloadLength(0) + fragmentIPHeader.SetPayloadLength(0) + if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) + } + + if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { + return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) + } + + if len(packets) > 1 { + // If the source packet was big enough that it needed fragmentation, let's + // inspect the fragment header. Because no other extension headers are + // supported, it will always be the last extension header. + fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength]) + + if got := fragmentHeader.More(); got != wantFragments[i].more { + return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more) + } + if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset { + return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset) + } + if got := fragmentHeader.NextHeader(); got != uint8(proto) { + return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto)) + } + } + + // Store the reassembled payload as we parse each fragment. The payload + // includes the Transport header and everything after. + reassembledPayload.AppendView(fragment.TransportHeader().View()) + reassembledPayload.Append(fragment.Data) + } + + result := reassembledPayload.ToView() + if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + + return nil +} + // TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and // UDP packets destined to the IPv6 link-local all-nodes multicast address. func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { @@ -170,8 +250,6 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // packets destined to the IPv6 solicited-node address of an assigned IPv6 // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - const nicID = 1 - tests := []struct { name string protocolFactory stack.TransportProtocolFactory @@ -195,7 +273,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv6EmptySubnet, NIC: nicID, }, @@ -295,17 +373,22 @@ func TestAddIpv6Address(t *testing.T) { } func TestReceiveIPv6ExtHdrs(t *testing.T) { - const nicID = 1 - tests := []struct { name string extHdr func(nextHdr uint8) ([]byte, uint8) shouldAccept bool + // Should we expect an ICMP response and if so, with what contents? + expectICMP bool + ICMPType header.ICMPv6Type + ICMPCode header.ICMPv6Code + pointer uint32 + multicast bool }{ { name: "None", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, shouldAccept: true, + expectICMP: false, }, { name: "hopbyhop with unknown option skippable action", @@ -336,9 +419,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, { - name: "hopbyhop with unknown option discard and send icmp action", + name: "hopbyhop with unknown option discard and send icmp action (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -348,12 +432,58 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest", + name: "hopbyhop with unknown option discard and send icmp action (multicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + multicast: true, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -364,39 +494,77 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: false, }, { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: true, }, { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 1, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6ErroneousHeader, + pointer: header.IPv6FixedHeaderSize + 2, }, { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 0, 0, 0, 0, + }, fragmentExtHdrID + }, shouldAccept: true, }, { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: true, + expectICMP: false, }, { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: false, + expectICMP: false, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{}, + noNextHdrID + }, shouldAccept: false, + expectICMP: false, }, { name: "destination with unknown option skippable action", @@ -412,6 +580,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: true, + expectICMP: false, }, { name: "destination with unknown option discard action", @@ -427,9 +596,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: false, + expectICMP: false, }, { - name: "destination with unknown option discard and send icmp action", + name: "destination with unknown option discard and send icmp action (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -439,12 +609,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action unless multicast dest", + name: "destination with unknown option discard and send icmp action (muilticast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. + }, destinationExtHdrID + }, + multicast: true, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -455,22 +651,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "routing - atomic fragment", + name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + nextHdr, 1, - // Fragment extension header. - nextHdr, 0, 0, 0, 1, 2, 3, 4, - }, routingExtHdrID + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. + }, destinationExtHdrID }, - shouldAccept: true, + shouldAccept: false, + expectICMP: false, + multicast: true, }, { name: "atomic fragment - routing", @@ -504,12 +711,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { return []byte{ // Routing extension header. hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. // Hop By Hop extension header with skippable unknown option. nextHdr, 0, 62, 4, 1, 2, 3, 4, }, routingExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { + name: "routing - hop by hop (with send icmp unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. + + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, routingExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, }, { name: "No next header", @@ -553,6 +790,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", @@ -573,6 +811,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, } @@ -582,7 +821,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -590,6 +829,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + // Add a default route so that a return packet knows where to go. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -631,12 +878,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: ipv6NextHdr, HopLimit: 255, SrcAddr: addr1, - DstAddr: addr2, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -650,6 +901,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = 0", got) } + if !test.expectICMP { + if p, ok := e.Read(); ok { + t.Fatalf("unexpected packet received: %#v", p) + } + return + } + + // ICMP required. + p, ok := e.Read() + if !ok { + t.Fatalf("expected packet wasn't written out") + } + + // Pack the output packet into a single buffer.View as the checkers + // assume that. + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() + if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { + t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) + } + + ipHdr := header.IPv6(pkt) + checker.IPv6(t, ipHdr, checker.ICMPv6( + checker.ICMPv6Type(test.ICMPType), + checker.ICMPv6Code(test.ICMPCode))) + + // We know we are looking at no extension headers in the error ICMP + // packets. + icmpPkt := header.ICMPv6(ipHdr.Payload()) + // We know we sent small packets that won't be truncated when reflected + // back to us. + originalPacket := icmpPkt.Payload() + if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { + t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) + } + if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { + t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) + } return } @@ -683,7 +972,6 @@ type fragmentData struct { func TestReceiveIPv6Fragments(t *testing.T) { const ( - nicID = 1 udpPayload1Length = 256 udpPayload2Length = 128 // Used to test cases where the fragment blocks are not a multiple of @@ -1748,7 +2036,7 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to find filter table") } ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } @@ -1770,10 +2058,10 @@ func TestWriteStats(t *testing.T) { } // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] - filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. - filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } @@ -1815,7 +2103,6 @@ func TestWriteStats(t *testing.T) { t.Run(test.name, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) rt := buildRoute(t, ep) - var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1857,12 +2144,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err) + t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) } { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) + mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") + subnet, err := tcpip.NewSubnet(dst, mask) if err != nil { - t.Fatalf("NewSubnet(_, _) failed: %v", err) + t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, @@ -1871,7 +2159,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { } rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err) } return rt } @@ -1895,3 +2183,320 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, lm.limit-- return false, false } + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint) + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[ep] + proto.mu.Unlock() + if !hasEP { + t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep) + } + } + + ep.Close() + + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[ep] + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep) + } + } +} + +type fragmentInfo struct { + offset uint16 + more bool + payloadSize uint16 +} + +type fragmentationTestCase struct { + description string + mtu uint32 + gso *stack.GSO + transHdrLen int + extraHdrLen int + payloadSize int + wantFragments []fragmentInfo + expectedFrags int +} + +var fragmentationTests = []fragmentationTestCase{ + { + description: "No Fragmentation", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1000, more: false}, + }, + }, + { + description: "Fragmented", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "No fragmentation with big header", + mtu: 2000, + gso: &stack.GSO{}, + transHdrLen: 100, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1100, more: false}, + }, + }, + { + description: "Fragmented with gso nil", + mtu: 1280, + gso: nil, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1400, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 176, more: false}, + }, + }, + { + description: "Fragmented with big header", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 100, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1200, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 76, more: false}, + }, + }, + { + description: "Fragmented with big header and prependable bytes", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 20, + extraHdrLen: header.IPv6MinimumSize + 66, + payloadSize: 1500, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 296, more: false}, + }, + }, +} + +func TestFragmentation(t *testing.T) { + const ( + ttl = 42 + tos = stack.DefaultTOS + transportProto = tcp.ProtocolNumber + ) + + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + source := pkt.Clone() + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != nil { + t.Fatalf("WritePacket(_, _, _): = %s", err) + } + if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + if len(ep.WrittenPackets) > 0 { + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + } + }) + } +} + +func TestFragmentationWritePackets(t *testing.T) { + const ttl = 42 + tests := []struct { + description string + insertBefore int + insertAfter int + }{ + { + description: "Single packet", + insertBefore: 0, + insertAfter: 0, + }, + { + description: "With packet before", + insertBefore: 1, + insertAfter: 0, + }, + { + description: "With packet after", + insertBefore: 0, + insertAfter: 1, + }, + { + description: "With packet before and after", + insertBefore: 1, + insertAfter: 1, + }, + } + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + var pkts stack.PacketBufferList + for i := 0; i < test.insertBefore; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + source := pkt + pkts.PushBack(pkt.Clone()) + for i := 0; i < test.insertAfter; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + + wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter + n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }) + if n != wantTotalPackets || err != nil { + t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets) + } + if got := len(ep.WrittenPackets); got != wantTotalPackets { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + + if wantTotalPackets == 0 { + return + } + + fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] + if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } + }) + } +} + +// TestFragmentationErrors checks that errors are returned from WritePacket +// correctly. +func TestFragmentationErrors(t *testing.T) { + const ttl = 42 + + tests := []struct { + description string + mtu uint32 + transHdrLen int + payloadSize int + allowPackets int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error + }{ + { + description: "No frag", + mtu: 2000, + payloadSize: 1000, + transHdrLen: 0, + allowPackets: 0, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on first frag", + mtu: 1300, + payloadSize: 3000, + transHdrLen: 0, + allowPackets: 0, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on second frag", + mtu: 1500, + payloadSize: 4000, + transHdrLen: 0, + allowPackets: 1, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on packet with MTU smaller than transport header", + mtu: 1280, + transHdrLen: 1500, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrMessageTooLong, + }, + } + + for _, ft := range tests { + t.Run(ft.description, func(t *testing.T) { + pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) + } + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) + } + }) + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 97ca00d16..48a4c65e3 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack +package ipv6 import ( "fmt" @@ -23,9 +23,27 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + // defaultDupAddrDetectTransmits is the default number of NDP Neighbor // Solicitation messages to send when doing Duplicate Address Detection // for a tentative address. @@ -34,7 +52,7 @@ const ( defaultDupAddrDetectTransmits = 1 // defaultMaxRtrSolicitations is the default number of Router - // Solicitation messages to send when a NIC becomes enabled. + // Solicitation messages to send when an IPv6 endpoint becomes enabled. // // Default = 3 (from RFC 4861 section 10). defaultMaxRtrSolicitations = 3 @@ -131,7 +149,7 @@ const ( minRegenAdvanceDuration = time.Duration(0) // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt - // SLAAC address regenerations in response to a NIC-local conflict. + // SLAAC address regenerations in response to an IPv6 endpoint-local conflict. maxSLAACAddrLocalRegenAttempts = 10 ) @@ -163,7 +181,7 @@ var ( // This is exported as a variable (instead of a constant) so tests // can update it to a smaller value. // - // This value guarantees that a temporary address will be preferred for at + // This value guarantees that a temporary address is preferred for at // least 1hr if the SLAAC prefix is valid for at least that time. MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour @@ -173,11 +191,17 @@ var ( // This is exported as a variable (instead of a constant) so tests // can update it to a smaller value. // - // This value guarantees that a temporary address will be valid for at least + // This value guarantees that a temporary address is valid for at least // 2hrs if the SLAAC prefix is valid for at least that time. MinMaxTempAddrValidLifetime = 2 * time.Hour ) +// NDPEndpoint is an endpoint that supports NDP. +type NDPEndpoint interface { + // SetNDPConfigurations sets the NDP configurations. + SetNDPConfigurations(NDPConfigurations) +} + // DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an // NDP Router Advertisement informed the Stack about. type DHCPv6ConfigurationFromNDPRA int @@ -192,7 +216,7 @@ const ( // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. // // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 - // will return all available configuration information. + // returns all available configuration information when serving addresses. DHCPv6ManagedAddress // DHCPv6OtherConfigurations indicates that other configuration information is @@ -207,19 +231,18 @@ const ( // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { - // OnDuplicateAddressDetectionStatus will be called when the DAD process - // for an address (addr) on a NIC (with ID nicID) completes. resolved - // will be set to true if DAD completed successfully (no duplicate addr - // detected); false otherwise (addr was detected to be a duplicate on - // the link the NIC is a part of, or it was stopped for some other - // reason, such as the address being removed). If an error occured - // during DAD, err will be set and resolved must be ignored. + // OnDuplicateAddressDetectionStatus is called when the DAD process for an + // address (addr) on a NIC (with ID nicID) completes. resolved is set to true + // if DAD completed successfully (no duplicate addr detected); false otherwise + // (addr was detected to be a duplicate on the link the NIC is a part of, or + // it was stopped for some other reason, such as the address being removed). + // If an error occured during DAD, err is set and resolved must be ignored. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) - // OnDefaultRouterDiscovered will be called when a new default router is + // OnDefaultRouterDiscovered is called when a new default router is // discovered. Implementations must return true if the newly discovered // router should be remembered. // @@ -227,56 +250,55 @@ type NDPDispatcher interface { // is also not permitted to call into the stack. OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool - // OnDefaultRouterInvalidated will be called when a discovered default - // router that was remembered is invalidated. + // OnDefaultRouterInvalidated is called when a discovered default router that + // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) - // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is - // discovered. Implementations must return true if the newly discovered - // on-link prefix should be remembered. + // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. + // Implementations must return true if the newly discovered on-link prefix + // should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool - // OnOnLinkPrefixInvalidated will be called when a discovered on-link - // prefix that was remembered is invalidated. + // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that + // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) - // OnAutoGenAddress will be called when a new prefix with its - // autonomous address-configuration flag set has been received and SLAAC - // has been performed. Implementations may prevent the stack from - // assigning the address to the NIC by returning false. + // OnAutoGenAddress is called when a new prefix with its autonomous address- + // configuration flag set is received and SLAAC was performed. Implementations + // may prevent the stack from assigning the address to the NIC by returning + // false. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool - // OnAutoGenAddressDeprecated will be called when an auto-generated - // address (as part of SLAAC) has been deprecated, but is still - // considered valid. Note, if an address is invalidated at the same - // time it is deprecated, the deprecation event MAY be omitted. + // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC) + // is deprecated, but is still considered valid. Note, if an address is + // invalidated at the same ime it is deprecated, the deprecation event may not + // be received. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) - // OnAutoGenAddressInvalidated will be called when an auto-generated - // address (as part of SLAAC) has been invalidated. + // OnAutoGenAddressInvalidated is called when an auto-generated address + // (SLAAC) is invalidated. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) - // OnRecursiveDNSServerOption will be called when an NDP option with - // recursive DNS servers has been received. Note, addrs may contain - // link-local addresses. + // OnRecursiveDNSServerOption is called when the stack learns of DNS servers + // through NDP. Note, the addresses may contain link-local addresses. // // It is up to the caller to use the DNS Servers only for their valid // lifetime. OnRecursiveDNSServerOption may be called for new or @@ -288,8 +310,8 @@ type NDPDispatcher interface { // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) - // OnDNSSearchListOption will be called when an NDP option with a DNS - // search list has been received. + // OnDNSSearchListOption is called when the stack learns of DNS search lists + // through NDP. // // It is up to the caller to use the domain names in the search list // for only their valid lifetime. OnDNSSearchListOption may be called @@ -298,8 +320,8 @@ type NDPDispatcher interface { // be increased, decreased or completely invalidated when lifetime = 0. OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) - // OnDHCPv6Configuration will be called with an updated configuration that is - // available via DHCPv6 for a specified NIC. + // OnDHCPv6Configuration is called with an updated configuration that is + // available via DHCPv6 for the passed NIC. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. @@ -320,7 +342,7 @@ type NDPConfigurations struct { // Must be greater than or equal to 1ms. RetransmitTimer time.Duration - // The number of Router Solicitation messages to send when the NIC + // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. MaxRtrSolicitations uint8 @@ -335,24 +357,22 @@ type NDPConfigurations struct { // Must be greater than or equal to 0s. MaxRtrSolicitationDelay time.Duration - // HandleRAs determines whether or not Router Advertisements will be - // processed. + // HandleRAs determines whether or not Router Advertisements are processed. HandleRAs bool - // DiscoverDefaultRouters determines whether or not default routers will - // be discovered from Router Advertisements. This configuration is - // ignored if HandleRAs is false. + // DiscoverDefaultRouters determines whether or not default routers are + // discovered from Router Advertisements, as per RFC 4861 section 6. This + // configuration is ignored if HandleRAs is false. DiscoverDefaultRouters bool - // DiscoverOnLinkPrefixes determines whether or not on-link prefixes - // will be discovered from Router Advertisements' Prefix Information - // option. This configuration is ignored if HandleRAs is false. + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are + // discovered from Router Advertisements' Prefix Information option, as per + // RFC 4861 section 6. This configuration is ignored if HandleRAs is false. DiscoverOnLinkPrefixes bool - // AutoGenGlobalAddresses determines whether or not global IPv6 - // addresses will be generated for a NIC in response to receiving a new - // Prefix Information option with its Autonomous Address - // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs + // SLAAC to auto-generate global SLAAC addresses in response to Prefix + // Information options, as per RFC 4862. // // Note, if an address was already generated for some unique prefix, as // part of SLAAC, this option does not affect whether or not the @@ -366,12 +386,12 @@ type NDPConfigurations struct { // // If the method used to generate the address does not support creating // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's - // MAC address), then no attempt will be made to resolve the conflict. + // MAC address), then no attempt is made to resolve the conflict. AutoGenAddressConflictRetries uint8 // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC - // addresses will be generated for a NIC as part of SLAAC privacy extensions, - // RFC 4941. + // addresses are generated for an IPv6 endpoint as part of SLAAC privacy + // extensions, as per RFC 4941. // // Ignored if AutoGenGlobalAddresses is false. AutoGenTempGlobalAddresses bool @@ -410,7 +430,7 @@ func DefaultNDPConfigurations() NDPConfigurations { } // validate modifies an NDPConfigurations with valid values. If invalid values -// are present in c, the corresponding default values will be used instead. +// are present in c, the corresponding default values are used instead. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer @@ -439,8 +459,8 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { - // The NIC this ndpState is for. - nic *NIC + // The IPv6 endpoint this ndpState is for. + ep *endpoint // configs is the per-interface NDP configurations. configs NDPConfigurations @@ -458,8 +478,8 @@ type ndpState struct { // Used to let the Router Solicitation timer know that it has been stopped. // // Must only be read from or written to while protected by the lock of - // the NIC this ndpState is associated with. MUST be set when the timer is - // set. + // the IPv6 endpoint this ndpState is associated with. MUST be set when the + // timer is set. done *bool } @@ -492,7 +512,7 @@ type dadState struct { // Used to let the DAD timer know that it has been stopped. // // Must only be read from or written to while protected by the lock of - // the NIC this dadState is associated with. + // the IPv6 endpoint this dadState is associated with. done *bool } @@ -537,7 +557,7 @@ type tempSLAACAddrState struct { // The address's endpoint. // // Must not be nil. - ref *referencedNetworkEndpoint + addressEndpoint stack.AddressEndpoint // Has a new temporary SLAAC address already been regenerated? regenerated bool @@ -567,10 +587,10 @@ type slaacPrefixState struct { // // May only be nil when the address is being (re-)generated. Otherwise, // must not be nil as all SLAAC prefixes must have a stable address. - ref *referencedNetworkEndpoint + addressEndpoint stack.AddressEndpoint - // The number of times an address has been generated locally where the NIC - // already had the generated address. + // The number of times an address has been generated locally where the IPv6 + // endpoint already had the generated address. localGenerationFailures uint8 } @@ -578,11 +598,12 @@ type slaacPrefixState struct { tempAddrs map[tcpip.Address]tempSLAACAddrState // The next two fields are used by both stable and temporary addresses - // generated for a SLAAC prefix. This is safe as only 1 address will be - // in the generation and DAD process at any time. That is, no two addresses - // will be generated at the same time for a given SLAAC prefix. + // generated for a SLAAC prefix. This is safe as only 1 address is in the + // generation and DAD process at any time. That is, no two addresses are + // generated at the same time for a given SLAAC prefix. - // The number of times an address has been generated and added to the NIC. + // The number of times an address has been generated and added to the IPv6 + // endpoint. // // Addresses may be regenerated in reseponse to a DAD conflicts. generationAttempts uint8 @@ -597,16 +618,16 @@ type slaacPrefixState struct { // This function must only be called by IPv6 addresses that are currently // tentative. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { // addr must be a valid unicast IPv6 address. if !header.IsV6UnicastAddress(addr) { return tcpip.ErrAddressFamilyNotSupported } - if ref.getKind() != permanentTentative { + if addressEndpoint.GetKind() != stack.PermanentTentative { // The endpoint should be marked as tentative since we are starting DAD. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } // Should not attempt to perform DAD on an address that is currently in the @@ -617,18 +638,18 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // existed, we would get an error since we attempted to add a duplicate // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. - // See NIC.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) + // See endpoint.addAddressLocked. + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID())) } remaining := ndp.configs.DupAddrDetectTransmits if remaining == 0 { - ref.setKind(permanent) + addressEndpoint.SetKind(stack.Permanent) // Consider DAD to have resolved even if no DAD messages were actually // transmitted. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } return nil @@ -637,25 +658,25 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref var done bool var timer tcpip.Timer // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the NIC's lock. This is effectively the same - // as starting a goroutine but we use a timer that fires immediately so we can - // reset it for the next DAD iteration. - timer = ndp.nic.stack.Clock().AfterFunc(0, func() { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() + // cannot be done while holding the IPv6 endpoint's lock. This is effectively + // the same as starting a goroutine but we use a timer that fires immediately + // so we can reset it for the next DAD iteration. + timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { + ndp.ep.mu.Lock() + defer ndp.ep.mu.Unlock() if done { // If we reach this point, it means that the DAD timer fired after - // another goroutine already obtained the NIC lock and stopped DAD - // before this function obtained the NIC lock. Simply return here and do - // nothing further. + // another goroutine already obtained the IPv6 endpoint lock and stopped + // DAD before this function obtained the NIC lock. Simply return here and + // do nothing further. return } - if ref.getKind() != permanentTentative { + if addressEndpoint.GetKind() != stack.PermanentTentative { // The endpoint should still be marked as tentative since we are still // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())) + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } dadDone := remaining == 0 @@ -663,33 +684,34 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref var err *tcpip.Error if !dadDone { // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) // Do not hold the lock when sending packets which may be a long running // task or may block link address resolution. We know this is safe // because immediately after obtaining the lock again, we check if DAD - // has been stopped before doing any work with the NIC. Note, DAD would be - // stopped if the NIC was disabled or removed, or if the address was - // removed. - ndp.nic.mu.Unlock() - err = ndp.sendDADPacket(addr, ref) - ndp.nic.mu.Lock() + // has been stopped before doing any work with the IPv6 endpoint. Note, + // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if + // the address was removed. + ndp.ep.mu.Unlock() + err = ndp.sendDADPacket(addr, addressEndpoint) + ndp.ep.mu.Lock() + addressEndpoint.DecRef() } if done { // If we reach this point, it means that DAD was stopped after we released - // the NIC's read lock and before we obtained the write lock. + // the IPv6 endpoint's read lock and before we obtained the write lock. return } if dadDone { // DAD has resolved. - ref.setKind(permanent) + addressEndpoint.SetKind(stack.Permanent) } else if err == nil { // DAD is not done and we had no errors when sending the last NDP NS, // schedule the next DAD timer. remaining-- - timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) + timer.Reset(ndp.configs.RetransmitTimer) return } @@ -698,16 +720,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // integrator know DAD has completed. delete(ndp.dad, addr) - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) } // If DAD resolved for a stable SLAAC address, attempt generation of a // temporary SLAAC address. - if dadDone && ref.configType == slaac { + if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) } }) @@ -722,28 +744,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns // addr. // -// addr must be a tentative IPv6 address on ndp's NIC. +// addr must be a tentative IPv6 address on ndp's IPv6 endpoint. // -// The NIC ndp belongs to MUST NOT be locked. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { +// The IPv6 endpoint that ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r := makeRoute(header.IPv6ProtocolNumber, ref.address(), snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } defer r.Release() // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the NIC is not locked, - // the NIC could have been disabled or removed by another goroutine. + // Since this method is required to be called when the IPv6 endpoint is not + // locked, the NIC could have been disabled or removed by another goroutine. if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { return err } - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) } icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) @@ -752,17 +777,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd ns.SetTargetAddress(addr) icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - pkt := NewPacketBuffer(PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, - NetworkHeaderParams{ + stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - TOS: DefaultTOS, }, pkt, ); err != nil { sent.Dropped.Increment() @@ -778,11 +802,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd // such a state forever, unless some other external event resolves the DAD // process (receiving an NA from the true owner of addr, or an NS for addr // (implying another node is attempting to use addr)). It is up to the caller -// of this function to handle such a scenario. Normally, addr will be removed -// from n right after this function returns or the address successfully -// resolved. +// of this function to handle such a scenario. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { dad, ok := ndp.dad[addr] if !ok { @@ -801,30 +823,30 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) } } // handleRA handles a Router Advertisement message that arrived on the NIC // this ndp is for. Does nothing if the NIC is configured to not handle RAs. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { - // Is the NIC configured to handle RAs at all? + // Is the IPv6 endpoint configured to handle RAs at all? // // Currently, the stack does not determine router interface status on a - // per-interface basis; it is a stack-wide configuration, so we check - // stack's forwarding flag to determine if the NIC is a routing - // interface. - if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) { + // per-interface basis; it is a protocol-wide configuration, so we check the + // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding + // packets. + if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { return } // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we // only inform the dispatcher on configuration changes. We do nothing else // with the information. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { var configuration DHCPv6ConfigurationFromNDPRA switch { case ra.ManagedAddrConfFlag(): @@ -839,11 +861,11 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { if ndp.dhcpv6Configuration != configuration { ndp.dhcpv6Configuration = configuration - ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) + ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration) } } - // Is the NIC configured to discover default routers? + // Is the IPv6 endpoint configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { rtr, ok := ndp.defaultRouters[ip] rl := ra.RouterLifetime() @@ -881,20 +903,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { switch opt := opt.(type) { case header.NDPRecursiveDNSServer: - if ndp.nic.stack.ndpDisp == nil { + if ndp.ep.protocol.ndpDisp == nil { continue } addrs, _ := opt.Addresses() - ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) + ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) case header.NDPDNSSearchList: - if ndp.nic.stack.ndpDisp == nil { + if ndp.ep.protocol.ndpDisp == nil { continue } domainNames, _ := opt.DomainNames() - ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime()) + ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -928,7 +950,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // invalidateDefaultRouter invalidates a discovered default router. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { rtr, ok := ndp.defaultRouters[ip] @@ -942,32 +964,32 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) } } // rememberDefaultRouter remembers a newly discovered default router with IPv6 // link-local address ip with lifetime rl. // -// The router identified by ip MUST NOT already be known by the NIC. +// The router identified by ip MUST NOT already be known by the IPv6 endpoint. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return } // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { + if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { // Informed by the integrator to not remember the router, do // nothing further. return } state := defaultRouterState{ - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateDefaultRouter(ip) }), } @@ -982,22 +1004,22 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The prefix identified by prefix MUST NOT already be known. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return } // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { // Informed by the integrator to not remember the prefix, do // nothing further. return } state := onLinkPrefixState{ - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } @@ -1011,7 +1033,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) // invalidateOnLinkPrefix invalidates a discovered on-link prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { s, ok := ndp.onLinkPrefixes[prefix] @@ -1025,8 +1047,8 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) } } @@ -1036,7 +1058,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { // handleOnLinkPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the on-link flag is set. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { prefix := pi.Subnet() prefixState, ok := ndp.onLinkPrefixes[prefix] @@ -1089,7 +1111,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // handleAutonomousPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the autonomous flag is set. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { vl := pi.ValidLifetime() pl := pi.PreferredLifetime() @@ -1125,7 +1147,7 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform // // pl is the new preferred lifetime. vl is the new valid lifetime. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // If we do not already have an address for this prefix and the valid // lifetime is 0, no need to do anything further, as per RFC 4862 @@ -1142,15 +1164,15 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(state.stableAddr.ref) + ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint) }), - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1189,7 +1211,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } // If the address is assigned (DAD resolved), generate a temporary address. - if state.stableAddr.ref.getKind() == permanent { + if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) @@ -1198,32 +1220,27 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.slaacPrefixes[prefix] = state } -// addSLAACAddr adds a SLAAC address to the NIC. +// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint { // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.nic.stack.ndpDisp + ndpDisp := ndp.ep.protocol.ndpDisp if ndpDisp == nil { return nil } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) { + if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { // Informed by the integrator not to add the address. return nil } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr, - } - - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) + addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) if err != nil { - panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) + panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } - return ref + return addressEndpoint } // generateSLAACAddr generates a SLAAC address for prefix. @@ -1232,10 +1249,10 @@ func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType netwo // // Panics if the prefix is not a SLAAC prefix or it already has an address. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { - if r := state.stableAddr.ref; r != nil { - panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) + if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { + panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix())) } // If we have already reached the maximum address generation attempts for the @@ -1255,11 +1272,11 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt } dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures - if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, - oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()), dadCounter, oIID.SecretKey, ) @@ -1272,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.nic.linkEP.LinkAddress() + linkAddr := ndp.ep.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { return false } @@ -1291,15 +1308,15 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt PrefixLen: validPrefixLenForAutoGen, } - if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) { break } state.stableAddr.localGenerationFailures++ } - if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { - state.stableAddr.ref = ref + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + state.stableAddr.addressEndpoint = addressEndpoint state.generationAttempts++ return true } @@ -1309,10 +1326,9 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // regenerateSLAACAddr regenerates an address for a SLAAC prefix. // -// If generating a new address for the prefix fails, the prefix will be -// invalidated. +// If generating a new address for the prefix fails, the prefix is invalidated. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { state, ok := ndp.slaacPrefixes[prefix] if !ok { @@ -1332,7 +1348,7 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { // generateTempSLAACAddr generates a new temporary SLAAC address. // -// If resetGenAttempts is true, the prefix's generation counter will be reset. +// If resetGenAttempts is true, the prefix's generation counter is reset. // // Returns true if a new address was generated. func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { @@ -1353,7 +1369,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - stableAddr := prefixState.stableAddr.ref.address() + stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address now := time.Now() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary @@ -1392,7 +1408,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - // Attempt to generate a new address that is not already assigned to the NIC. + // Attempt to generate a new address that is not already assigned to the IPv6 + // endpoint. var generatedAddr tcpip.AddressWithPrefix for i := 0; ; i++ { // If we were unable to generate an address after the maximum SLAAC address @@ -1402,7 +1419,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) - if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) { break } } @@ -1410,13 +1427,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary // address with a zero preferred lifetime. The checks above ensure this // so we know the address is not deprecated. - ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) - if ref == nil { + addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */) + if addressEndpoint == nil { return false } state := tempSLAACAddrState{ - deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1427,9 +1444,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) } - ndp.deprecateSLAACAddress(tempAddrState.ref) + ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) }), - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1442,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { + regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1465,8 +1482,8 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla prefixState.tempAddrs[generatedAddr.Address] = tempAddrState ndp.slaacPrefixes[prefix] = prefixState }), - createdAt: now, - ref: ref, + createdAt: now, + addressEndpoint: addressEndpoint, } state.deprecationJob.Schedule(pl) @@ -1481,7 +1498,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { state, ok := ndp.slaacPrefixes[prefix] if !ok { @@ -1496,14 +1513,14 @@ func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttemp // // pl is the new preferred lifetime. vl is the new valid lifetime. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 if deprecated { - ndp.deprecateSLAACAddress(prefixState.stableAddr.ref) + ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint) } else { - prefixState.stableAddr.ref.deprecated = false + prefixState.stableAddr.addressEndpoint.SetDeprecated(false) } // If prefix was preferred for some finite lifetime before, cancel the @@ -1565,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // If DAD is not yet complete on the stable address, there is no need to do // work with temporary addresses. - if prefixState.stableAddr.ref.getKind() != permanent { + if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent { return } @@ -1608,9 +1625,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat newPreferredLifetime := preferredUntil.Sub(now) tempAddrState.deprecationJob.Cancel() if newPreferredLifetime <= 0 { - ndp.deprecateSLAACAddress(tempAddrState.ref) + ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) } else { - tempAddrState.ref.deprecated = false + tempAddrState.addressEndpoint.SetDeprecated(false) tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } @@ -1635,8 +1652,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // due to an update in preferred lifetime. // // If each temporay address has already been regenerated, no new temporary - // address will be generated. To ensure continuation of temporary SLAAC - // addresses, we manually try to regenerate an address here. + // address is generated. To ensure continuation of temporary SLAAC addresses, + // we manually try to regenerate an address here. if len(regenForAddr) != 0 || allAddressesRegenerated { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. @@ -1647,57 +1664,58 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } } -// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP -// dispatcher that ref has been deprecated. +// deprecateSLAACAddress marks the address as deprecated and notifies the NDP +// dispatcher that address has been deprecated. // -// deprecateSLAACAddress does nothing if ref is already deprecated. +// deprecateSLAACAddress does nothing if the address is already deprecated. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { - if ref.deprecated { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) { + if addressEndpoint.Deprecated() { return } - ref.deprecated = true - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) + addressEndpoint.SetDeprecated(true) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } } // invalidateSLAACPrefix invalidates a SLAAC prefix. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { - if r := state.stableAddr.ref; r != nil { + ndp.cleanupSLAACPrefixResources(prefix, state) + + if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil { // Since we are already invalidating the prefix, do not invalidate the // prefix when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { - panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err)) + if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err)) } } - - ndp.cleanupSLAACPrefixResources(prefix, state) } // cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's // resources. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } prefix := addr.Subnet() state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.address() { + if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address { return } if !invalidatePrefix { // If the prefix is not being invalidated, disassociate the address from the // prefix and do nothing further. - state.stableAddr.ref = nil + state.stableAddr.addressEndpoint.DecRef() + state.stableAddr.addressEndpoint = nil ndp.slaacPrefixes[prefix] = state return } @@ -1709,14 +1727,17 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr // // Panics if the SLAAC prefix is not known. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { // Invalidate all temporary addresses. for tempAddr, tempAddrState := range state.tempAddrs { ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) } - state.stableAddr.ref = nil + if state.stableAddr.addressEndpoint != nil { + state.stableAddr.addressEndpoint.DecRef() + state.stableAddr.addressEndpoint = nil + } state.deprecationJob.Cancel() state.invalidationJob.Cancel() delete(ndp.slaacPrefixes, prefix) @@ -1724,12 +1745,12 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa // invalidateTempSLAACAddr invalidates a temporary SLAAC address. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { // Since we are already invalidating the address, do not invalidate the // address when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil { - panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err)) + if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err)) } ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) @@ -1738,10 +1759,10 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA // cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary // SLAAC address's resources from ndp. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } if !invalidateAddr { @@ -1765,35 +1786,29 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi // cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's // jobs and entry. // -// The NIC that ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + tempAddrState.addressEndpoint.DecRef() + tempAddrState.addressEndpoint = nil tempAddrState.deprecationJob.Cancel() tempAddrState.invalidationJob.Cancel() tempAddrState.regenJob.Cancel() delete(tempAddrs, tempAddr) } -// cleanupState cleans up ndp's state. -// -// If hostOnly is true, then only host-specific state will be cleaned up. +// removeSLAACAddresses removes all SLAAC addresses. // -// cleanupState MUST be called with hostOnly set to true when ndp's NIC is -// transitioning from a host to a router. This function will invalidate all -// discovered on-link prefixes, discovered routers, and auto-generated -// addresses. -// -// If hostOnly is true, then the link-local auto-generated address will not be -// invalidated as routers are also expected to generate a link-local address. +// If keepLinkLocal is false, the SLAAC generated link-local address is removed. // -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - linkLocalPrefixes := 0 + var linkLocalPrefixes int for prefix, state := range ndp.slaacPrefixes { // RFC 4862 section 5 states that routers are also expected to generate a // link-local address so we do not invalidate them if we are cleaning up // host-only state. - if hostOnly && prefix == linkLocalSubnet { + if keepLinkLocal && prefix == linkLocalSubnet { linkLocalPrefixes++ continue } @@ -1804,6 +1819,21 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) } +} + +// cleanupState cleans up ndp's state. +// +// If hostOnly is true, then only host-specific state is cleaned up. +// +// This function invalidates all discovered on-link prefixes, discovered +// routers, and auto-generated addresses. +// +// If hostOnly is true, then the link-local auto-generated address aren't +// invalidated as routers are also expected to generate a link-local address. +// +// The IPv6 endpoint that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupState(hostOnly bool) { + ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) @@ -1827,7 +1857,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // startSolicitingRouters starts soliciting routers, as per RFC 4861 section // 6.3.7. If routers are already being solicited, this function does nothing. // -// The NIC ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { if ndp.rtrSolicit.timer != nil { // We are already soliciting routers. @@ -1848,27 +1878,37 @@ func (ndp *ndpState) startSolicitingRouters() { var done bool ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { - ndp.nic.mu.Lock() + ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { + ndp.ep.mu.Lock() if done { // If we reach this point, it means that the RS timer fired after another - // goroutine already obtained the NIC lock and stopped solicitations. - // Simply return here and do nothing further. - ndp.nic.mu.Unlock() + // goroutine already obtained the IPv6 endpoint lock and stopped + // solicitations. Simply return here and do nothing further. + ndp.ep.mu.Unlock() return } // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) - if ref == nil { - ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false) + if addressEndpoint == nil { + // Incase this ends up creating a new temporary address, we need to hold + // onto the endpoint until a route is obtained. If we decrement the + // reference count before obtaing a route, the address's resources would + // be released and attempting to obtain a route after would fail. Once a + // route is obtainted, it is safe to decrement the reference count since + // obtaining a route increments the address's reference count. + addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) } - ndp.nic.mu.Unlock() + ndp.ep.mu.Unlock() - localAddr := ref.address() - r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + localAddr := addressEndpoint.AddressWithPrefix().Address + r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) + addressEndpoint.DecRef() + if err != nil { + return + } defer r.Release() // Route should resolve immediately since @@ -1876,15 +1916,16 @@ func (ndp *ndpState) startSolicitingRouters() { // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the NIC is not locked, - // the NIC could have been disabled or removed by another goroutine. + // Since this method is required to be called when the IPv6 endpoint is + // not locked, the IPv6 endpoint could have been disabled or removed by + // another goroutine. if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { return } - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1907,21 +1948,20 @@ func (ndp *ndpState) startSolicitingRouters() { rs.Options().Serialize(optsSerializer) icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - pkt := NewPacketBuffer(PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, - NetworkHeaderParams{ + stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - TOS: DefaultTOS, }, pkt, ); err != nil { sent.Dropped.Increment() - log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) + log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. remaining = 0 } else { @@ -1929,19 +1969,19 @@ func (ndp *ndpState) startSolicitingRouters() { remaining-- } - ndp.nic.mu.Lock() + ndp.ep.mu.Lock() if done || remaining == 0 { ndp.rtrSolicit.timer = nil ndp.rtrSolicit.done = nil } else if ndp.rtrSolicit.timer != nil { // Note, we need to explicitly check to make sure that // the timer field is not nil because if it was nil but - // we still reached this point, then we know the NIC + // we still reached this point, then we know the IPv6 endpoint // was requested to stop soliciting routers so we don't // need to send the next Router Solicitation message. ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) } - ndp.nic.mu.Unlock() + ndp.ep.mu.Unlock() }) } @@ -1949,7 +1989,7 @@ func (ndp *ndpState) startSolicitingRouters() { // stopSolicitingRouters stops soliciting routers. If routers are not currently // being solicited, this function does nothing. // -// The NIC ndp belongs to MUST be locked. +// The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { if ndp.rtrSolicit.timer == nil { // Nothing to do. @@ -1965,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() { // initializeTempAddrState initializes state related to temporary SLAAC // addresses. func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID()) + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index c93d1194f..9033a9ed5 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -17,6 +17,7 @@ package ipv6 import ( "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -65,10 +66,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + t.Cleanup(ep.Close) + return s, ep } +var _ NDPDispatcher = (*testNDPDispatcher)(nil) + +// testNDPDispatcher is an NDPDispatcher only allows default router discovery. +type testNDPDispatcher struct { + addr tcpip.Address +} + +func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +} + +func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { + t.addr = addr + return true +} + +func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { + t.addr = addr +} + +func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { + return false +} + +func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { +} + +func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { + return false +} + +func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { +} + +func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) { +} + +func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) { +} + +func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) { +} + +func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) { +} + +func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { + var ndpDisp testNDPDispatcher + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ + NDPDisp: &ndpDisp, + })}, + }) + + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) + } + + ipv6EP := ep.(*endpoint) + ipv6EP.mu.Lock() + ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) + ipv6EP.mu.Unlock() + + if ndpDisp.addr != lladdr1 { + t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) + } + + ndpDisp.addr = "" + ndpEP := ep.(stack.NDPEndpoint) + ndpEP.InvalidateDefaultRouter(lladdr1) + if ndpDisp.addr != lladdr1 { + t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1) + } +} + // TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -325,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { naDst tcpip.Address }{ { - name: "Unspecified source to multicast destination", + name: "Unspecified source to solicited-node multicast destination", nsOpts: nil, nsSrcLinkAddr: remoteLinkAddr0, nsSrc: header.IPv6Any, @@ -352,11 +437,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsSrcLinkAddr: remoteLinkAddr0, nsSrc: header.IPv6Any, nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: false, - naSrc: nicAddr, - naDst: header.IPv6AllNodesMulticastAddress, + nsInvalid: true, }, { name: "Unspecified source with source ll option to unicast destination", diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index c9e57dc0d..d0ffc299a 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -8,6 +8,7 @@ go_library( "testutil.go", ], visibility = [ + "//pkg/tcpip/network/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", ], diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 7f1d79115..2eaeab779 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -54,8 +54,8 @@ go_template_instance( go_library( name = "stack", srcs = [ + "addressable_endpoint_state.go", "conntrack.go", - "dhcpv6configurationfromndpra_string.go", "forwarder.go", "headertype_string.go", "icmp_rate_limit.go", @@ -65,7 +65,6 @@ go_library( "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", - "ndp.go", "neighbor_cache.go", "neighbor_entry.go", "neighbor_entry_list.go", @@ -106,6 +105,7 @@ go_test( name = "stack_x_test", size = "medium", srcs = [ + "addressable_endpoint_state_test.go", "ndp_test.go", "nud_test.go", "stack_test.go", @@ -116,6 +116,7 @@ go_test( deps = [ ":stack", "//pkg/rand", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go new file mode 100644 index 000000000..4d3acab96 --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -0,0 +1,753 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) +var _ AddressableEndpoint = (*AddressableEndpointState)(nil) + +// AddressableEndpointState is an implementation of an AddressableEndpoint. +type AddressableEndpointState struct { + networkEndpoint NetworkEndpoint + + // Lock ordering (from outer to inner lock ordering): + // + // AddressableEndpointState.mu + // addressState.mu + mu struct { + sync.RWMutex + + endpoints map[tcpip.Address]*addressState + primary []*addressState + + // groups holds the mapping between group addresses and the number of times + // they have been joined. + groups map[tcpip.Address]uint32 + } +} + +// Init initializes the AddressableEndpointState with networkEndpoint. +// +// Must be called before calling any other function on m. +func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { + a.networkEndpoint = networkEndpoint + + a.mu.Lock() + defer a.mu.Unlock() + a.mu.endpoints = make(map[tcpip.Address]*addressState) + a.mu.groups = make(map[tcpip.Address]uint32) +} + +// ReadOnlyAddressableEndpointState provides read-only access to an +// AddressableEndpointState. +type ReadOnlyAddressableEndpointState struct { + inner *AddressableEndpointState +} + +// AddrOrMatching returns an endpoint for the passed address that is consisdered +// bound to the wrapped AddressableEndpointState. +// +// If addr is an exact match with an existing address, that address is returned. +// Otherwise, f is called with each address and the address that f returns true +// for is returned. +// +// Returns nil of no address matches. +func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + if ep, ok := m.inner.mu.endpoints[addr]; ok { + if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { + return ep + } + } + + for _, ep := range m.inner.mu.endpoints { + if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { + return ep + } + } + + return nil +} + +// Lookup returns the AddressEndpoint for the passed address. +// +// Returns nil if the passed address is not associated with the +// AddressableEndpointState. +func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + ep, ok := m.inner.mu.endpoints[addr] + if !ok { + return nil + } + return ep +} + +// ForEach calls f for each address pair. +// +// If f returns false, f is no longer be called. +func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + for _, ep := range m.inner.mu.endpoints { + if !f(ep) { + return + } + } +} + +// ForEachPrimaryEndpoint calls f for each primary address. +// +// If f returns false, f is no longer be called. +func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + for _, ep := range m.inner.mu.primary { + f(ep) + } +} + +// ReadOnly returns a readonly reference to a. +func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { + return ReadOnlyAddressableEndpointState{inner: a} +} + +func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { + a.mu.Lock() + defer a.mu.Unlock() + a.releaseAddressStateLocked(addrState) +} + +// releaseAddressState removes addrState from s's address state (primary and endpoints list). +// +// Preconditions: a.mu must be write locked. +func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) { + oldPrimary := a.mu.primary + for i, s := range a.mu.primary { + if s == addrState { + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + oldPrimary[len(oldPrimary)-1] = nil + break + } + } + delete(a.mu.endpoints, addrState.addr.Address) +} + +// AddAndAcquirePermanentAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil, err + } + return ep, err +} + +// AddAndAcquireTemporaryAddress adds a temporary address. +// +// Returns tcpip.ErrDuplicateAddress if the address exists. +// +// The temporary address's endpoint is acquired and returned. +func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil, err + } + return ep, err +} + +// addAndAcquireAddressLocked adds, acquires and returns a permanent or +// temporary address. +// +// If the addressable endpoint already has the address in a non-permanent state, +// and addAndAcquireAddressLocked is adding a permanent address, that address is +// promoted in place and its properties set to the properties provided. If the +// address already exists in any other state, then tcpip.ErrDuplicateAddress is +// returned, regardless the kind of address that is being added. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) { + // attemptAddToPrimary is false when the address is already in the primary + // address list. + attemptAddToPrimary := true + addrState, ok := a.mu.endpoints[addr.Address] + if ok { + if !permanent { + // We are adding a non-permanent address but the address exists. No need + // to go any further since we can only promote existing temporary/expired + // addresses to permanent. + return nil, tcpip.ErrDuplicateAddress + } + + addrState.mu.Lock() + if addrState.mu.kind.IsPermanent() { + addrState.mu.Unlock() + // We are adding a permanent address but a permanent address already + // exists. + return nil, tcpip.ErrDuplicateAddress + } + + if addrState.mu.refs == 0 { + panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr)) + } + + // We now promote the address. + for i, s := range a.mu.primary { + if s == addrState { + switch peb { + case CanBePrimaryEndpoint: + // The address is already in the primary address list. + attemptAddToPrimary = false + case FirstPrimaryEndpoint: + if i == 0 { + // The address is already first in the primary address list. + attemptAddToPrimary = false + } else { + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + } + case NeverPrimaryEndpoint: + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + default: + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + } + break + } + } + } + + if addrState == nil { + addrState = &addressState{ + addressableEndpointState: a, + addr: addr, + } + a.mu.endpoints[addr.Address] = addrState + addrState.mu.Lock() + // We never promote an address to temporary - it can only be added as such. + // If we are actaully adding a permanent address, it is promoted below. + addrState.mu.kind = Temporary + } + + // At this point we have an address we are either promoting from an expired or + // temporary address to permanent, promoting an expired address to temporary, + // or we are adding a new temporary or permanent address. + // + // The address MUST be write locked at this point. + defer addrState.mu.Unlock() + + if permanent { + if addrState.mu.kind.IsPermanent() { + panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr)) + } + + // Primary addresses are biased by 1. + addrState.mu.refs++ + addrState.mu.kind = Permanent + } + // Acquire the address before returning it. + addrState.mu.refs++ + addrState.mu.deprecated = deprecated + addrState.mu.configType = configType + + if attemptAddToPrimary { + switch peb { + case NeverPrimaryEndpoint: + case CanBePrimaryEndpoint: + a.mu.primary = append(a.mu.primary, addrState) + case FirstPrimaryEndpoint: + if cap(a.mu.primary) == len(a.mu.primary) { + a.mu.primary = append([]*addressState{addrState}, a.mu.primary...) + } else { + // Shift all the endpoints by 1 to make room for the new address at the + // front. We could have just created a new slice but this saves + // allocations when the slice has capacity for the new address. + primaryCount := len(a.mu.primary) + a.mu.primary = append(a.mu.primary, nil) + if n := copy(a.mu.primary[1:], a.mu.primary); n != primaryCount { + panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount)) + } + a.mu.primary[0] = addrState + } + default: + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + } + } + + return addrState, nil +} + +// RemovePermanentAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + a.mu.Lock() + defer a.mu.Unlock() + + if _, ok := a.mu.groups[addr]; ok { + panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) + } + + return a.removePermanentAddressLocked(addr) +} + +// removePermanentAddressLocked is like RemovePermanentAddress but with locking +// requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { + addrState, ok := a.mu.endpoints[addr] + if !ok { + return tcpip.ErrBadLocalAddress + } + + return a.removePermanentEndpointLocked(addrState) +} + +// RemovePermanentEndpoint removes the passed endpoint if it is associated with +// a and permanent. +func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error { + addrState, ok := ep.(*addressState) + if !ok || addrState.addressableEndpointState != a { + return tcpip.ErrInvalidEndpointState + } + + return a.removePermanentEndpointLocked(addrState) +} + +// removePermanentAddressLocked is like RemovePermanentAddress but with locking +// requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error { + if !addrState.GetKind().IsPermanent() { + return tcpip.ErrBadLocalAddress + } + + addrState.SetKind(PermanentExpired) + a.decAddressRefLocked(addrState) + return nil +} + +// decAddressRef decrements the address's reference count and releases it once +// the reference count hits 0. +func (a *AddressableEndpointState) decAddressRef(addrState *addressState) { + a.mu.Lock() + defer a.mu.Unlock() + a.decAddressRefLocked(addrState) +} + +// decAddressRefLocked is like decAddressRef but with locking requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) { + addrState.mu.Lock() + defer addrState.mu.Unlock() + + if addrState.mu.refs == 0 { + panic(fmt.Sprintf("attempted to decrease ref count for AddressEndpoint w/ addr = %s when it is already released", addrState.addr)) + } + + addrState.mu.refs-- + + if addrState.mu.refs != 0 { + return + } + + // A non-expired permanent address must not have its reference count dropped + // to 0. + if addrState.mu.kind.IsPermanent() { + panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.mu.kind)) + } + + a.releaseAddressStateLocked(addrState) +} + +// MainAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.GetKind() == Permanent + }) + if ep == nil { + return tcpip.AddressWithPrefix{} + } + + addr := ep.AddressWithPrefix() + a.decAddressRefLocked(ep) + return addr +} + +// acquirePrimaryAddressRLocked returns an acquired primary address that is +// valid according to isValid. +// +// Precondition: e.mu must be read locked +func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState { + var deprecatedEndpoint *addressState + for _, ep := range a.mu.primary { + if !isValid(ep) { + continue + } + + if !ep.Deprecated() { + if ep.IncRef() { + // ep is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + a.decAddressRefLocked(deprecatedEndpoint) + deprecatedEndpoint = nil + } + + return ep + } + } else if deprecatedEndpoint == nil && ep.IncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of + // ep in case a doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, ep's reference count + // will be decremented. + deprecatedEndpoint = ep + } + } + + return deprecatedEndpoint +} + +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + a.mu.Lock() + defer a.mu.Unlock() + + if addrState, ok := a.mu.endpoints[localAddr]; ok { + if !addrState.IsAssigned(allowTemp) { + return nil + } + + if !addrState.IncRef() { + panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + } + + return addrState + } + + if !allowTemp { + return nil + } + + addr := localAddr.WithPrefix() + ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + if err != nil { + // addAndAcquireAddressLocked only returns an error if the address is + // already assigned but we just checked above if the address exists so we + // expect no error. + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + } + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil + } + return ep +} + +// AcquireOutgoingPrimaryAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.IsAssigned(allowExpired) + }) + + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil + } + + return ep +} + +// PrimaryAddresses implements AddressableEndpoint. +func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + var addrs []tcpip.AddressWithPrefix + for _, ep := range a.mu.primary { + // Don't include tentative, expired or temporary endpoints + // to avoid confusion and prevent the caller from using + // those. + switch ep.GetKind() { + case PermanentTentative, PermanentExpired, Temporary: + continue + } + + addrs = append(addrs, ep.AddressWithPrefix()) + } + + return addrs +} + +// PermanentAddresses implements AddressableEndpoint. +func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + var addrs []tcpip.AddressWithPrefix + for _, ep := range a.mu.endpoints { + if !ep.GetKind().IsPermanent() { + continue + } + + addrs = append(addrs, ep.AddressWithPrefix()) + } + + return addrs +} + +// JoinGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + + joins, ok := a.mu.groups[group] + if !ok { + ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */) + if err != nil { + return false, err + } + // We have no need for the address endpoint. + a.decAddressRefLocked(ep) + } + + a.mu.groups[group] = joins + 1 + return !ok, nil +} + +// LeaveGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + + joins, ok := a.mu.groups[group] + if !ok { + return false, tcpip.ErrBadLocalAddress + } + + if joins == 1 { + a.removeGroupAddressLocked(group) + delete(a.mu.groups, group) + return true, nil + } + + a.mu.groups[group] = joins - 1 + return false, nil +} + +// IsInGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { + a.mu.RLock() + defer a.mu.RUnlock() + _, ok := a.mu.groups[group] + return ok +} + +func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) { + if err := a.removePermanentAddressLocked(group); err != nil { + // removePermanentEndpointLocked would only return an error if group is + // not bound to the addressable endpoint, but we know it MUST be assigned + // since we have group in our map of groups. + panic(fmt.Sprintf("error removing group address = %s: %s", group, err)) + } +} + +// Cleanup forcefully leaves all groups and removes all permanent addresses. +func (a *AddressableEndpointState) Cleanup() { + a.mu.Lock() + defer a.mu.Unlock() + + for group := range a.mu.groups { + a.removeGroupAddressLocked(group) + } + a.mu.groups = make(map[tcpip.Address]uint32) + + for _, ep := range a.mu.endpoints { + // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is + // not a permanent address. + if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err)) + } + } +} + +var _ AddressEndpoint = (*addressState)(nil) + +// addressState holds state for an address. +type addressState struct { + addressableEndpointState *AddressableEndpointState + addr tcpip.AddressWithPrefix + + // Lock ordering (from outer to inner lock ordering): + // + // AddressableEndpointState.mu + // addressState.mu + mu struct { + sync.RWMutex + + refs uint32 + kind AddressKind + configType AddressConfigType + deprecated bool + } +} + +// AddressWithPrefix implements AddressEndpoint. +func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { + return a.addr +} + +// GetKind implements AddressEndpoint. +func (a *addressState) GetKind() AddressKind { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.kind +} + +// SetKind implements AddressEndpoint. +func (a *addressState) SetKind(kind AddressKind) { + a.mu.Lock() + defer a.mu.Unlock() + a.mu.kind = kind +} + +// IsAssigned implements AddressEndpoint. +func (a *addressState) IsAssigned(allowExpired bool) bool { + if !a.addressableEndpointState.networkEndpoint.Enabled() { + return false + } + + switch a.GetKind() { + case PermanentTentative: + return false + case PermanentExpired: + return allowExpired + default: + return true + } +} + +// IncRef implements AddressEndpoint. +func (a *addressState) IncRef() bool { + a.mu.Lock() + defer a.mu.Unlock() + if a.mu.refs == 0 { + return false + } + + a.mu.refs++ + return true +} + +// DecRef implements AddressEndpoint. +func (a *addressState) DecRef() { + a.addressableEndpointState.decAddressRef(a) +} + +// ConfigType implements AddressEndpoint. +func (a *addressState) ConfigType() AddressConfigType { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.configType +} + +// SetDeprecated implements AddressEndpoint. +func (a *addressState) SetDeprecated(d bool) { + a.mu.Lock() + defer a.mu.Unlock() + a.mu.deprecated = d +} + +// Deprecated implements AddressEndpoint. +func (a *addressState) Deprecated() bool { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.deprecated +} diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go new file mode 100644 index 000000000..26787d0a3 --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// TestAddressableEndpointStateCleanup tests that cleaning up an addressable +// endpoint state removes permanent addresses and leaves groups. +func TestAddressableEndpointStateCleanup(t *testing.T) { + var ep fakeNetworkEndpoint + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + var s stack.AddressableEndpointState + s.Init(&ep) + + addr := tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + } + + { + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + if err != nil { + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + } + // We don't need the address endpoint. + ep.DecRef() + } + { + ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) + if ep == nil { + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address) + } + ep.DecRef() + } + + group := tcpip.Address("\x02") + if added, err := s.JoinGroup(group); err != nil { + t.Fatalf("s.JoinGroup(%s): %s", group, err) + } else if !added { + t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) + } + if !s.IsInGroup(group) { + t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) + } + + s.Cleanup() + { + ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) + } + } + if s.IsInGroup(group) { + t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + } +} diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 836682ea0..0cd1da11f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -196,13 +196,14 @@ type bucket struct { // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. +// +// Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader().View()) - if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := pkt.Network() + if netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol @@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { dstAddr: netHeader.DestinationAddress(), dstPort: tcpHeader.DestinationPort(), transProto: netHeader.TransportProtocol(), - netProto: header.IPv4ProtocolNumber, + netProto: pkt.NetworkProtocolNumber, }, nil } @@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -281,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt Redirec // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.MinIP - replyTID.srcPort = rt.MinPort + replyTID.srcAddr = rt.Addr + replyTID.srcPort = rt.Port var manip manipType switch hook { case Prerouting: @@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction @@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { // support cases when they are validated, e.g. when we can't offload // receive checksumming. - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacketOutput manipulates ports for packets in Output hook. @@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction @@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacket will manipulate the port and address of the packet if the @@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return } @@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1 dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, transProto: header.TCPProtocolNumber, - netProto: header.IPv4ProtocolNumber, + netProto: netProto, } conn, _ := ct.connForTID(tid) if conn == nil { diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 8d18f3c8c..4e4b00a92 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -45,6 +46,8 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fwdTestNetworkEndpoint struct { + AddressableEndpointState + nicID tcpip.NICID proto *fwdTestNetworkProtocol dispatcher TransportDispatcher @@ -53,12 +56,18 @@ type fwdTestNetworkEndpoint struct { var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) +func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { + return nil +} + +func (*fwdTestNetworkEndpoint) Enabled() bool { + return true } -func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID +func (*fwdTestNetworkEndpoint) Disable() {} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { @@ -78,10 +87,6 @@ func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportPr return 0 } -func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -106,7 +111,9 @@ func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBu return tcpip.ErrNotSupported } -func (*fwdTestNetworkEndpoint) Close() {} +func (f *fwdTestNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() +} // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. @@ -116,6 +123,11 @@ type fwdTestNetworkProtocol struct { addrResolveDelay time.Duration onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) + + mu struct { + sync.RWMutex + forwarding bool + } } var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) @@ -145,13 +157,15 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { - return &fwdTestNetworkEndpoint{ - nicID: nicID, +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { + e := &fwdTestNetworkEndpoint{ + nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, + ep: nic.LinkEndpoint(), } + e.AddressableEndpointState.Init(e) + return e } func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { @@ -186,6 +200,21 @@ func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber return fwdTestNetNumber } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding + +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + // fwdTestPacketInfo holds all the information about an outbound packet. type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 4a521eca9..8d6d9a7f1 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -60,11 +60,11 @@ func DefaultTables() *IPTables { v4Tables: [numTables]Table{ natID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -83,9 +83,9 @@ func DefaultTables() *IPTables { }, mangleID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -101,10 +101,10 @@ func DefaultTables() *IPTables { }, filterID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: HookUnset, @@ -125,11 +125,11 @@ func DefaultTables() *IPTables { v6Tables: [numTables]Table{ natID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -148,9 +148,9 @@ func DefaultTables() *IPTables { }, mangleID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -166,10 +166,10 @@ func DefaultTables() *IPTables { }, filterID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: HookUnset, @@ -502,11 +502,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, tcpip.ErrNotConnected } - return it.connections.originalDst(epID) + return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 5f1b2af64..538c4625d 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -21,78 +21,139 @@ import ( ) // AcceptTarget accepts packets. -type AcceptTarget struct{} +type AcceptTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (at *AcceptTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: at.NetworkProtocol, + } +} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } // DropTarget drops packets. -type DropTarget struct{} +type DropTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (dt *DropTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: dt.NetworkProtocol, + } +} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } +// ErrorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const ErrorTargetName = "ERROR" + // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. -type ErrorTarget struct{} +type ErrorTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (et *ErrorTarget) ID() TargetID { + return TargetID{ + Name: ErrorTargetName, + NetworkProtocol: et.NetworkProtocol, + } +} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } // UserChainTarget marks a rule as the beginning of a user chain. type UserChainTarget struct { + // Name is the chain name. Name string + + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (uc *UserChainTarget) ID() TargetID { + return TargetID{ + Name: ErrorTargetName, + NetworkProtocol: uc.NetworkProtocol, + } } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } // ReturnTarget returns from the current chain. If the chain is a built-in, the // hook's underflow should be called. -type ReturnTarget struct{} +type ReturnTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (rt *ReturnTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: rt.NetworkProtocol, + } +} // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } +// RedirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const RedirectTargetName = "REDIRECT" + // RedirectTarget redirects the packet by modifying the destination port/IP. -// Min and Max values for IP and Ports in the struct indicate the range of -// values which can be used to redirect. +// TODO(gvisor.dev/issue/170): Other flags need to be added after we support +// them. type RedirectTarget struct { - // TODO(gvisor.dev/issue/170): Other flags need to be added after - // we support them. - // RangeProtoSpecified flag indicates single port is specified to - // redirect. - RangeProtoSpecified bool + // Addr indicates address used to redirect. + Addr tcpip.Address - // MinIP indicates address used to redirect. - MinIP tcpip.Address + // Port indicates port used to redirect. + Port uint16 - // MaxIP indicates address used to redirect. - MaxIP tcpip.Address - - // MinPort indicates port used to redirect. - MinPort uint16 + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} - // MaxPort indicates port used to redirect. - MaxPort uint16 +// ID implements Target.ID. +func (rt *RedirectTarget) ID() TargetID { + return TargetID{ + Name: RedirectTargetName, + NetworkProtocol: rt.NetworkProtocol, + } } // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -103,34 +164,35 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1) in Output and - // to primary address of the incoming interface in Prerouting. + // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // the primary address of the incoming interface in Prerouting. switch hook { case Output: - rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1}) - rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1}) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + } else { + rt.Addr = header.IPv6Loopback + } case Prerouting: - rt.MinIP = address - rt.MaxIP = address + rt.Addr = address default: panic("redirect target is supported only on output and prerouting hooks") } // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader().View()) - switch protocol := netHeader.TransportProtocol(); protocol { + switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetDestinationPort(rt.MinPort) + udpHeader.SetDestinationPort(rt.Port) // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(protocol, length) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) @@ -139,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - // Change destination address. - netHeader.SetDestinationAddress(rt.MinIP) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + + pkt.Network().SetDestinationAddress(rt.Addr) + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 093ee6881..7b3f3e88b 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -104,8 +104,20 @@ type IPTables struct { reaperDone chan struct{} } -// A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules. +// A Table defines a set of chains and hooks into the network stack. +// +// It is a list of Rules, entry points (BuiltinChains), and error handlers +// (Underflows). As packets traverse netstack, they hit hooks. When a packet +// hits a hook, iptables compares it to Rules starting from that hook's entry +// point. So if a packet hits the Input hook, we look up the corresponding +// entry point in BuiltinChains and jump to that point. +// +// If the Rule doesn't match the packet, iptables continues to the next Rule. +// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept +// or RuleDrop) that causes the packet to stop traversing iptables. It can also +// jump to other rules or perform custom actions based on Rule.Target. +// +// Underflow Rules are invoked when a chain returns without reaching a verdict. // // +stateify savable type Table struct { @@ -260,6 +272,18 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return true } +// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header +// applies. +func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber { + switch len(fl.Src) { + case header.IPv4AddressSize: + return header.IPv4ProtocolNumber + case header.IPv6AddressSize: + return header.IPv6ProtocolNumber + } + panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src)) +} + // filterAddress returns whether addr matches the filter. func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { matches := true @@ -285,8 +309,23 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } +// A TargetID uniquely identifies a target. +type TargetID struct { + // Name is the target name as stored in the xt_entry_target struct. + Name string + + // NetworkProtocol is the protocol to which the target applies. + NetworkProtocol tcpip.NetworkProtocolNumber + + // Revision is the version of the target. + Revision uint8 +} + // A Target is the interface for taking an action for a packet. type Target interface { + // ID uniquely identifies the Target. + ID() TargetID + // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8416dbcdb..73a01c2dd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -150,10 +150,10 @@ type ndpDNSSLEvent struct { type ndpDHCPv6Event struct { nicID tcpip.NICID - configuration stack.DHCPv6ConfigurationFromNDPRA + configuration ipv6.DHCPv6ConfigurationFromNDPRA } -var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) +var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. @@ -170,7 +170,7 @@ type ndpDispatcher struct { dhcpv6ConfigurationC chan ndpDHCPv6Event } -// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ @@ -182,7 +182,7 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add } } -// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { if c := n.routerC; c != nil { c <- ndpRouterEvent{ @@ -195,7 +195,7 @@ func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip. return n.rememberRouter } -// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { if c := n.routerC; c != nil { c <- ndpRouterEvent{ @@ -206,7 +206,7 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip } } -// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ @@ -219,7 +219,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip return n.rememberPrefix } -// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. +// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ @@ -261,7 +261,7 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi } } -// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption. func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { if c := n.rdnssC; c != nil { c <- ndpRDNSSEvent{ @@ -274,7 +274,7 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } -// Implements stack.NDPDispatcher.OnDNSSearchListOption. +// Implements ipv6.NDPDispatcher.OnDNSSearchListOption. func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { if n.dnsslC != nil { n.dnsslC <- ndpDNSSLEvent{ @@ -285,8 +285,8 @@ func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []s } } -// Implements stack.NDPDispatcher.OnDHCPv6Configuration. -func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { +// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) { if c := n.dhcpv6ConfigurationC; c != nil { c <- ndpDHCPv6Event{ nicID, @@ -319,13 +319,12 @@ func TestDADDisabled(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPDisp: &ndpDisp, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -413,19 +412,21 @@ func TestDADResolve(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = test.retransTimer - opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits e := channelLinkWithHeaderLength{ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ipv6.NDPConfigurations{ + RetransmitTimer: test.retransTimer, + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + }, + })}, + }) if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -558,6 +559,26 @@ func TestDADResolve(t *testing.T) { } } +func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(tgt) + snmc := header.SolicitedNodeAddr(tgt) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) +} + // TestDADFail tests to make sure that the DAD process fails if another node is // detected to be performing DAD on the same address (receive an NS message from // a node doing DAD for the same address), or if another node is detected to own @@ -567,39 +588,19 @@ func TestDADFail(t *testing.T) { tests := []struct { name string - makeBuf func(tgt tcpip.Address) buffer.Prependable + rxPkt func(e *channel.Endpoint, tgt tcpip.Address) getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ { - "RxSolicit", - func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - - return hdr - - }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RxSolicit", + rxPkt: rxNDPSolicit, + getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborSolicit }, }, { - "RxAdvert", - func(tgt tcpip.Address) buffer.Prependable { + name: "RxAdvert", + rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) @@ -621,11 +622,9 @@ func TestDADFail(t *testing.T) { SrcAddr: tgt, DstAddr: header.IPv6AllNodesMulticastAddress, }) - - return hdr - + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborAdvert }, }, @@ -636,16 +635,16 @@ func TestDADFail(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = time.Second * 2 + ndpConfigs := ipv6.DefaultNDPConfigurations() + ndpConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -664,13 +663,8 @@ func TestDADFail(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - // Receive a packet to simulate multiple nodes owning or - // attempting to own the same address. - hdr := test.makeBuf(addr1) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - }) - e.InjectInbound(header.IPv6ProtocolNumber, pkt) + // Receive a packet to simulate an address conflict. + test.rxPkt(e, addr1) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -754,18 +748,19 @@ func TestDADStop(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := stack.NDPConfigurations{ + + ndpConfigs := ipv6.NDPConfigurations{ RetransmitTimer: time.Second, DupAddrDetectTransmits: 2, } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -815,19 +810,6 @@ func TestDADStop(t *testing.T) { } } -// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NDP configurations using an invalid NICID. -func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - }) - - // No NIC with ID 1 yet. - if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) - } -} - // TestSetNDPConfigurations tests that we can update and use per-interface NDP // configurations without affecting the default NDP configurations or other // interfaces' configurations. @@ -863,8 +845,9 @@ func TestSetNDPConfigurations(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + })}, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -892,12 +875,15 @@ func TestSetNDPConfigurations(t *testing.T) { } // Update the NDP configurations on NIC(1) to use DAD. - configs := stack.NDPConfigurations{ + configs := ipv6.NDPConfigurations{ DupAddrDetectTransmits: test.dupAddrDetectTransmits, RetransmitTimer: test.retransmitTimer, } - if err := s.SetNDPConfigurations(nicID1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(configs) } // Created after updating NIC(1)'s NDP configurations @@ -1113,12 +1099,13 @@ func TestNoRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + })}, }) s.SetForwarding(ipv6.ProtocolNumber, forwarding) @@ -1151,12 +1138,13 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1192,12 +1180,13 @@ func TestRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) expectRouterEvent := func(addr tcpip.Address, discovered bool) { @@ -1285,7 +1274,7 @@ func TestRouterDiscovery(t *testing.T) { } // TestRouterDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), @@ -1293,12 +1282,13 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1306,14 +1296,14 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) - if i <= stack.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredDefaultRouters { select { case e := <-ndpDisp.routerC: if diff := checkRouterEvent(e, llAddr, true); diff != "" { @@ -1358,12 +1348,13 @@ func TestNoPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + DiscoverOnLinkPrefixes: discover, + }, + NDPDisp: &ndpDisp, + })}, }) s.SetForwarding(ipv6.ProtocolNumber, forwarding) @@ -1399,13 +1390,14 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1445,12 +1437,13 @@ func TestPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1545,12 +1538,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1621,33 +1615,34 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } // TestPrefixDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. +// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) } - optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} + optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2) + prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} // Receive an RA with 2 more than the max number of discovered on-link // prefixes. - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} prefixAddr[7] = byte(i) prefix := tcpip.AddressWithPrefix{ @@ -1665,8 +1660,8 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < stack.MaxDiscoveredOnLinkPrefixes { + for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < ipv6.MaxDiscoveredOnLinkPrefixes { select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { @@ -1716,12 +1711,13 @@ func TestNoAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + })}, }) s.SetForwarding(ipv6.ProtocolNumber, forwarding) @@ -1749,14 +1745,14 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. -func TestAutoGenAddr(t *testing.T) { +func TestAutoGenAddr2(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1766,12 +1762,13 @@ func TestAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1876,14 +1873,14 @@ func TestAutoGenTempAddr(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := stack.MaxDesyncFactor + savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate + savedMaxDesync := ipv6.MaxDesyncFactor defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - stack.MaxDesyncFactor = savedMaxDesync + ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate + ipv6.MaxDesyncFactor = savedMaxDesync }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MaxDesyncFactor = time.Nanosecond prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1931,16 +1928,17 @@ func TestAutoGenTempAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2119,11 +2117,11 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMaxDesyncFactor := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MaxDesyncFactor = savedMaxDesyncFactor }() - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MaxDesyncFactor = time.Nanosecond tests := []struct { name string @@ -2160,12 +2158,13 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenIPv6LinkLocal: true, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenIPv6LinkLocal: true, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2211,11 +2210,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMaxDesyncFactor := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MaxDesyncFactor = savedMaxDesyncFactor }() - stack.MaxDesyncFactor = 0 + ipv6.MaxDesyncFactor = 0 prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2228,15 +2227,16 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2294,17 +2294,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + savedMaxDesyncFactor := ipv6.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor - stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + ipv6.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime }() - stack.MaxDesyncFactor = 0 - stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration - stack.MinMaxTempAddrValidLifetime = newMinVLDuration + ipv6.MaxDesyncFactor = 0 + ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration + ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2317,16 +2317,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2382,8 +2383,11 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) } // Wait for all the temporary addresses to get invalidated. @@ -2439,17 +2443,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + savedMaxDesyncFactor := ipv6.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor - stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + ipv6.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime }() - stack.MaxDesyncFactor = 0 - stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration - stack.MinMaxTempAddrValidLifetime = newMinVLDuration + ipv6.MaxDesyncFactor = 0 + ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration + ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2462,16 +2466,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2545,9 +2550,12 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // as paased. ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) select { case e := <-ndpDisp.autoGenAddrC: @@ -2565,9 +2573,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) - } + ndpEP.SetNDPConfigurations(ndpConfigs) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } @@ -2655,20 +2661,21 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: test.nicNameFromID, + }, + })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: test.nicNameFromID, - }, }) s.SetRouteTable([]tcpip.Route{{ @@ -2739,8 +2746,11 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { ndpDisp.dadC = make(chan ndpDADEvent, 2) ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits ndpConfigs.RetransmitTimer = retransmitTimer - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) } // Do SLAAC for prefix. @@ -2754,9 +2764,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // DAD failure to restart the local generation process. addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] expectAutoGenAddrAsyncEvent(addr, newAddr) - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { @@ -2794,14 +2802,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: ndpDisp, - UseNeighborCache: useNeighborCache, + UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -3036,11 +3045,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { for _, stackTyp := range stacks { t.Run(stackTyp.name, func(t *testing.T) { - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -3258,12 +3267,12 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 const minVLSeconds = 1 savedIL := header.NDPInfiniteLifetime - savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL header.NDPInfiniteLifetime = savedIL }() - stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3307,12 +3316,13 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3357,11 +3367,11 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 const newMinVL = 4 - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3449,12 +3459,13 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { } e := channel.New(10, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3515,12 +3526,13 @@ func TestAutoGenAddrRemoval(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3700,12 +3712,13 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3781,18 +3794,19 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, }, - SecretKey: secretKey, - }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3856,11 +3870,11 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const lifetimeSeconds = 10 // Needed for the temporary address sub test. - savedMaxDesync := stack.MaxDesyncFactor + savedMaxDesync := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesync + ipv6.MaxDesyncFactor = savedMaxDesync }() - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MaxDesyncFactor = time.Nanosecond var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte secretKey := secretKeyBuf[:] @@ -3938,14 +3952,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrTypes := []struct { name string - ndpConfigs stack.NDPConfigurations + ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { name: "Global address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -3963,7 +3977,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, { name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, }, @@ -3977,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, { name: "Temporary address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -4029,16 +4043,17 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, }, - SecretKey: secretKey, - }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4059,9 +4074,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) expectDADEvent(t, &ndpDisp, addr.Address, false) @@ -4119,14 +4132,14 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrTypes := []struct { name string - ndpConfigs stack.NDPConfigurations + ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool subnet tcpip.Subnet triggerSLAACFn func(e *channel.Endpoint) }{ { name: "Global address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -4142,7 +4155,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { }, { name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, AutoGenAddressConflictRetries: maxRetries, @@ -4165,10 +4178,11 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4198,9 +4212,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: @@ -4250,21 +4262,22 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, }, - SecretKey: secretKey, - }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4296,9 +4309,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Simulate a DAD conflict after some time has passed. time.Sleep(failureTimer) - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: @@ -4459,11 +4470,12 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -4509,11 +4521,12 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4694,15 +4707,16 @@ func TestCleanupNDPState(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) expectRouterEvent := func() (bool, ndpRouterEvent) { @@ -4967,18 +4981,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) { t.Helper() select { case e := <-ndpDisp.dhcpv6ConfigurationC: @@ -5002,7 +5017,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Even if the first RA reports no DHCPv6 configurations are available, the // dispatcher should get an event. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) + expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) // Receiving the same update again should not result in an event to the // dispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) @@ -5011,19 +5026,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() // Receive an RA that updates the DHCPv6 configuration to Managed Address. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) + expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) expectNoDHCPv6Event() // Receive an RA that updates the DHCPv6 configuration to none. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) + expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) expectNoDHCPv6Event() @@ -5031,7 +5046,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // // Note, when the M flag is set, the O flag is redundant. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) + expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) expectNoDHCPv6Event() // Even though the DHCPv6 flags are different, the effective configuration is @@ -5044,7 +5059,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() @@ -5059,7 +5074,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() } @@ -5217,12 +5232,13 @@ func TestRouterSolicitation(t *testing.T) { } } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, }) if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5357,12 +5373,13 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.NDPRS()) } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 213646160..9a72bec79 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // NeighborEntry describes a neighboring device in the local network. @@ -439,7 +440,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.notifyWakersLocked() } - if e.isRouter && !flags.IsRouter { + if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { // "In those cases where the IsRouter flag changes from TRUE to FALSE as // a result of this update, the node MUST remove that router from the // Default Router List and update the Destination Cache entries for all @@ -447,9 +448,17 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // 7.3.3. This is needed to detect when a node that is used as a router // stops forwarding packets due to being configured as a host." // - RFC 4861 section 7.2.5 - e.nic.mu.Lock() - e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr) - e.nic.mu.Unlock() + // + // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 + // here. + ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + if !ok { + panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) + } + + if ndpEP, ok := ep.(NDPEndpoint); ok { + ndpEP.InvalidateDefaultRouter(e.neigh.Addr) + } } e.isRouter = flags.IsRouter diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index e530ec7ea..a265fff0a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -233,18 +234,15 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudDisp: &disp, }, } + nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ + header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), + } rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) - // Stub out ndpState to verify modification of default routers. - nic.mu.ndp = ndpState{ - nic: &nic, - defaultRouters: make(map[tcpip.Address]defaultRouterState), - } - // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ nic: &nic, @@ -817,6 +815,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) + ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + e.mu.Lock() e.handlePacketQueuedLocked() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ @@ -830,9 +830,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { if got, want := e.isRouter, true; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) } - e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{ - invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}), - } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -841,8 +839,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { if got, want := e.isRouter, false; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) } - if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok { - t.Errorf("unexpected defaultRouter for %s", entryTestAddr1) + if ipv6EP.invalidatedRtr != e.neigh.Addr { + t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr) } e.mu.Unlock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 2875a5b60..6cf54cc89 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -18,7 +18,6 @@ import ( "fmt" "math/rand" "reflect" - "sort" "sync/atomic" "gvisor.dev/gvisor/pkg/sleep" @@ -28,13 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -var ipv4BroadcastAddr = tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 8 * header.IPv4AddressSize, - }, -} +var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. @@ -45,22 +38,25 @@ type NIC struct { linkEP LinkEndpoint context NICContext - stats NICStats - neigh *neighborCache + stats NICStats + neigh *neighborCache + + // The network endpoints themselves may be modified by calling the interface's + // methods, but the map reference and entries must be constant. networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. + // + // Must be accessed using atomic operations. + enabled uint32 + mu struct { sync.RWMutex - enabled bool spoofing bool promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - mcastJoins map[NetworkEndpointID]uint32 // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint - ndp ndpState } } @@ -84,25 +80,6 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } -// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior. -type PrimaryEndpointBehavior int - -const ( - // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. - CanBePrimaryEndpoint PrimaryEndpointBehavior = iota - - // FirstPrimaryEndpoint indicates the endpoint should be the first - // primary endpoint considered. If there are multiple endpoints with - // this behavior, the most recently-added one will be first. - FirstPrimaryEndpoint - - // NeverPrimaryEndpoint indicates the endpoint should never be a - // primary endpoint. - NeverPrimaryEndpoint -) - // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -122,19 +99,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) - nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) - nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) - nic.mu.ndp = ndpState{ - nic: nic, - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - nic.mu.ndp.initializeTempAddrState() // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -162,7 +127,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = nil - nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nud, nic, ep, stack) + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } nic.linkEP.Attach(nic) @@ -170,29 +135,32 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } -// enabled returns true if n is enabled. -func (n *NIC) enabled() bool { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - return enabled +func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { + return n.networkEndpoints[proto] } -// disable disables n. +// Enabled implements NetworkInterface. +func (n *NIC) Enabled() bool { + return atomic.LoadUint32(&n.enabled) == 1 +} + +// setEnabled sets the enabled status for the NIC. // -// It undoes the work done by enable. -func (n *NIC) disable() *tcpip.Error { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - if !enabled { - return nil +// Returns true if the enabled status was updated. +func (n *NIC) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&n.enabled, 1) == 0 } + return atomic.SwapUint32(&n.enabled, 0) == 1 +} +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() { n.mu.Lock() - err := n.disableLocked() + n.disableLocked() n.mu.Unlock() - return err } // disableLocked disables n. @@ -200,9 +168,9 @@ func (n *NIC) disable() *tcpip.Error { // It undoes the work done by enable. // // n MUST be locked. -func (n *NIC) disableLocked() *tcpip.Error { - if !n.mu.enabled { - return nil +func (n *NIC) disableLocked() { + if !n.setEnabled(false) { + return } // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be @@ -210,38 +178,9 @@ func (n *NIC) disableLocked() *tcpip.Error { // again, and applications may not know that the underlying NIC was ever // disabled. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { - n.mu.ndp.stopSolicitingRouters() - n.mu.ndp.cleanupState(false /* hostOnly */) - - // Stop DAD for all the unicast IPv6 endpoints that are in the - // permanentTentative state. - for _, r := range n.mu.endpoints { - if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { - n.mu.ndp.stopDuplicateAddressDetection(addr) - } - } - - // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } - } - - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } - - // The address may have already been removed. - if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } + for _, ep := range n.networkEndpoints { + ep.Disable() } - - n.mu.enabled = false - return nil } // enable enables n. @@ -251,162 +190,38 @@ func (n *NIC) disableLocked() *tcpip.Error { // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - if enabled { - return nil - } - n.mu.Lock() defer n.mu.Unlock() - if n.mu.enabled { - return nil - } - - n.mu.enabled = true - - // Create an endpoint to receive broadcast packets on this interface. - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } - - // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts - // multicast group. Note, the IANA calls the all-hosts multicast group the - // all-systems multicast group. - if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil { - return err - } - } - - // Join the IPv6 All-Nodes Multicast group if the stack is configured to - // use IPv6. This is required to ensure that this node properly receives - // and responds to the various NDP messages that are destined to the - // all-nodes multicast address. An example is the Neighbor Advertisement - // when we perform Duplicate Address Detection, or Router Advertisement - // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 - // section 4.2 for more information. - // - // Also auto-generate an IPv6 link-local address based on the NIC's - // link address if it is configured to do so. Note, each interface is - // required to have IPv6 link-local unicast address, as per RFC 4291 - // section 2.1. - _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] - if !ok { + if !n.setEnabled(true) { return nil } - // Join the All-Nodes multicast group before starting DAD as responses to DAD - // (NDP NS) messages may be sent to the All-Nodes multicast group if the - // source address of the NDP NS is the unspecified address, as per RFC 4861 - // section 7.2.4. - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - - // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent - // state. - // - // Addresses may have aleady completed DAD but in the time since the NIC was - // last enabled, other devices may have acquired the same addresses. - for _, r := range n.mu.endpoints { - addr := r.address() - if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { - continue - } - - r.setKind(permanentTentative) - if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil { + for _, ep := range n.networkEndpoints { + if err := ep.Enable(); err != nil { return err } } - // Do not auto-generate an IPv6 link-local address for loopback devices. - if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { - // The valid and preferred lifetime is infinite for the auto-generated - // link-local address. - n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) - } - - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyways. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. - if !n.stack.Forwarding(header.IPv6ProtocolNumber) { - n.mu.ndp.startSolicitingRouters() - } - return nil } -// remove detaches NIC from the link endpoint, and marks existing referenced -// network endpoints expired. This guarantees no packets between this NIC and -// the network stack. +// remove detaches NIC from the link endpoint and releases network endpoint +// resources. This guarantees no packets between this NIC and the network +// stack. func (n *NIC) remove() *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() n.disableLocked() - // TODO(b/151378115): come up with a better way to pick an error than the - // first one. - var err *tcpip.Error - - // Forcefully leave multicast groups. - for nid := range n.mu.mcastJoins { - if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { - err = tempErr - } - } - - // Remove permanent and permanentTentative addresses, so no packet goes out. - for nid, ref := range n.mu.endpoints { - switch ref.getKind() { - case permanentTentative, permanent: - if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { - err = tempErr - } - } - } - - // Release any resources the network endpoint may hold. for _, ep := range n.networkEndpoints { ep.Close() } // Detach from link endpoint, so no packet comes in. n.linkEP.Attach(nil) - - return err -} - -// becomeIPv6Router transitions n into an IPv6 router. -// -// When transitioning into an IPv6 router, host-only state (NDP discovered -// routers, discovered on-link prefixes, and auto-generated addresses) will -// be cleaned up/invalidated and NDP router solicitations will be stopped. -func (n *NIC) becomeIPv6Router() { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.cleanupState(true /* hostOnly */) - n.mu.ndp.stopSolicitingRouters() -} - -// becomeIPv6Host transitions n into an IPv6 host. -// -// When transitioning into an IPv6 host, NDP router solicitations will be -// started. -func (n *NIC) becomeIPv6Host() { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.startSolicitingRouters() + return nil } // setPromiscuousMode enables or disables promiscuous mode. @@ -423,7 +238,8 @@ func (n *NIC) isPromiscuousMode() bool { return rv } -func (n *NIC) isLoopback() bool { +// IsLoopback implements NetworkInterface. +func (n *NIC) IsLoopback() bool { return n.linkEP.Capabilities()&CapabilityLoopback != 0 } @@ -434,206 +250,44 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -// primaryEndpoint will return the first non-deprecated endpoint if such an -// endpoint exists for the given protocol and remoteAddr. If no non-deprecated -// endpoint exists, the first deprecated endpoint will be returned. -// -// If an IPv6 primary endpoint is requested, Source Address Selection (as -// defined by RFC 6724 section 5) will be performed. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { - if protocol == header.IPv6ProtocolNumber && len(remoteAddr) != 0 { - return n.primaryIPv6Endpoint(remoteAddr) - } - - n.mu.RLock() - defer n.mu.RUnlock() - - var deprecatedEndpoint *referencedNetworkEndpoint - for _, r := range n.mu.primary[protocol] { - if !r.isValidForOutgoingRLocked() { - continue - } - - if !r.deprecated { - if r.tryIncRef() { - // r is not deprecated, so return it immediately. - // - // If we kept track of a deprecated endpoint, decrement its reference - // count since it was incremented when we decided to keep track of it. - if deprecatedEndpoint != nil { - deprecatedEndpoint.decRefLocked() - deprecatedEndpoint = nil - } - - return r - } - } else if deprecatedEndpoint == nil && r.tryIncRef() { - // We prefer an endpoint that is not deprecated, but we keep track of r in - // case n doesn't have any non-deprecated endpoints. - // - // If we end up finding a more preferred endpoint, r's reference count - // will be decremented when such an endpoint is found. - deprecatedEndpoint = r - } - } - - // n doesn't have any valid non-deprecated endpoints, so return - // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated - // endpoints either). - return deprecatedEndpoint -} - -// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC -// 6724 section 5). -type ipv6AddrCandidate struct { - ref *referencedNetworkEndpoint - scope header.IPv6AddressScope -} - -// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { +// primaryAddress returns an address that can be used to communicate with +// remoteAddr. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { n.mu.RLock() - ref := n.primaryIPv6EndpointRLocked(remoteAddr) + spoofing := n.mu.spoofing n.mu.RUnlock() - return ref -} - -// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -// -// n.mu MUST be read locked. -func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { - primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] - - if len(primaryAddrs) == 0 { - return nil - } - - // Create a candidate set of available addresses we can potentially use as a - // source address. - cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) - for _, r := range primaryAddrs { - // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoingRLocked() { - continue - } - - addr := r.address() - scope, err := header.ScopeForIPv6Address(addr) - if err != nil { - // Should never happen as we got r from the primary IPv6 endpoint list and - // ScopeForIPv6Address only returns an error if addr is not an IPv6 - // address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) - } - - cs = append(cs, ipv6AddrCandidate{ - ref: r, - scope: scope, - }) - } - - remoteScope, err := header.ScopeForIPv6Address(remoteAddr) - if err != nil { - // primaryIPv6Endpoint should never be called with an invalid IPv6 address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) - } - - // Sort the addresses as per RFC 6724 section 5 rules 1-3. - // - // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. - sort.Slice(cs, func(i, j int) bool { - sa := cs[i] - sb := cs[j] - - // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.ref.address() == remoteAddr { - return true - } - if sb.ref.address() == remoteAddr { - return false - } - - // Prefer appropriate scope as per RFC 6724 section 5 rule 2. - if sa.scope < sb.scope { - return sa.scope >= remoteScope - } else if sb.scope < sa.scope { - return sb.scope < remoteScope - } - - // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. - if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { - // If sa is not deprecated, it is preferred over sb. - return sbDep - } - - // Prefer temporary addresses as per RFC 6724 section 5 rule 7. - if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp { - return saTemp - } - - // sa and sb are equal, return the endpoint that is closest to the front of - // the primary endpoint list. - return i < j - }) - - // Return the most preferred address that can have its reference count - // incremented. - for _, c := range cs { - if r := c.ref; r.tryIncRef() { - return r - } - } - - return nil -} - -// hasPermanentAddrLocked returns true if n has a permanent (including currently -// tentative) address, addr. -func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] + ep, ok := n.networkEndpoints[protocol] if !ok { - return false + return nil } - kind := ref.getKind() - - return kind == permanent || kind == permanentTentative + return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } -type getRefBehaviour int +type getAddressBehaviour int const ( // spoofing indicates that the NIC's spoofing flag should be observed when - // getting a NIC's referenced network endpoint. - spoofing getRefBehaviour = iota + // getting a NIC's address endpoint. + spoofing getAddressBehaviour = iota // promiscuous indicates that the NIC's promiscuous flag should be observed - // when getting a NIC's referenced network endpoint. + // when getting a NIC's address endpoint. promiscuous ) -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) +func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint { + return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } // findEndpoint finds the endpoint, if any, with the given address. -func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, address, peb, spoofing) +func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { + return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) } -// getRefEpOrCreateTemp returns the referenced network endpoint for the given -// protocol and address. +// getAddressEpOrCreateTemp returns the address endpoint for the given protocol +// and address. // // If none exists a temporary one may be created if we are in promiscuous mode // or spoofing. Promiscuous mode will only be checked if promiscuous is true. @@ -641,9 +295,8 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // // If the address is the IPv4 broadcast address for an endpoint's network, that // endpoint will be returned. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { +func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint { n.mu.RLock() - var spoofingOrPromiscuous bool switch tempRef { case spoofing: @@ -651,274 +304,54 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous } - - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // An endpoint with this id exists, check if it can be used and return it. - if !ref.isAssignedRLocked(spoofingOrPromiscuous) { - n.mu.RUnlock() - return nil - } - - if ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - } - - if protocol == header.IPv4ProtocolNumber { - if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil { - n.mu.RUnlock() - return ref - } - } n.mu.RUnlock() - - if !spoofingOrPromiscuous { - return nil - } - - // Try again with the lock in exclusive mode. If we still can't get the - // endpoint, create a new "temporary" endpoint. It will only exist while - // there's a route through it. - n.mu.Lock() - ref := n.getRefOrCreateTempLocked(protocol, address, peb) - n.mu.Unlock() - return ref + return n.getAddressOrCreateTempInner(protocol, address, spoofingOrPromiscuous, peb) } -// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the -// broadcast address for the endpoint's network or an address in the endpoint's -// subnet if the NIC is a loopback interface. This matches linux behaviour. -// -// n.mu MUST be read or write locked. -func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint { - for _, ref := range n.mu.endpoints { - // Only IPv4 has a notion of broadcast addresses or considers the loopback - // interface bound to an address's whole subnet (on linux). - if ref.protocol != header.IPv4ProtocolNumber { - continue - } - - subnet := ref.addrWithPrefix().Subnet() - if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() { - return ref - } +// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean +// is passed to indicate whether or not we should generate temporary endpoints. +func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { + if ep, ok := n.networkEndpoints[protocol]; ok { + return ep.AcquireAssignedAddress(address, createTemp, peb) } return nil } -/// getRefOrCreateTempLocked returns an existing endpoint for address or creates -/// and returns a temporary endpoint. -// -// If the address is the IPv4 broadcast address for an endpoint's network, that -// endpoint will be returned. -// -// n.mu must be write locked. -func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // No need to check the type as we are ok with expired endpoints at this - // point. - if ref.tryIncRef() { - return ref - } - // tryIncRef failing means the endpoint is scheduled to be removed once the - // lock is released. Remove it here so we can create a new (temporary) one. - // The removal logic waiting for the lock handles this case. - n.removeEndpointLocked(ref) - } - - if protocol == header.IPv4ProtocolNumber { - if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil { - return ref - } - } - - // Add a new temporary endpoint. - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return nil - } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, temporary, static, false) - return ref -} - -// addAddressLocked adds a new protocolAddress to n. -// -// If n already has the address in a non-permanent state, and the kind given is -// permanent, that address will be promoted in place and its properties set to -// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress. -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { - // TODO(b/141022673): Validate IP addresses before adding them. - - // Sanity check. - id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.mu.endpoints[id]; ok { - // Endpoint already exists. - if kind != permanent { - return nil, tcpip.ErrDuplicateAddress - } - switch ref.getKind() { - case permanentTentative, permanent: - // The NIC already have a permanent endpoint with that address. - return nil, tcpip.ErrDuplicateAddress - case permanentExpired, temporary: - // Promote the endpoint to become permanent and respect the new peb, - // configType and deprecated status. - if ref.tryIncRef() { - // TODO(b/147748385): Perform Duplicate Address Detection when promoting - // an IPv6 endpoint to permanent. - ref.setKind(permanent) - ref.deprecated = deprecated - ref.configType = configType - - refs := n.mu.primary[ref.protocol] - for i, r := range refs { - if r == ref { - switch peb { - case CanBePrimaryEndpoint: - return ref, nil - case FirstPrimaryEndpoint: - if i == 0 { - return ref, nil - } - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - case NeverPrimaryEndpoint: - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - return ref, nil - } - } - } - - n.insertPrimaryEndpointLocked(ref, peb) - - return ref, nil - } - // tryIncRef failing means the endpoint is scheduled to be removed once - // the lock is released. Remove it here so we can create a new - // (permanent) one. The removal logic waiting for the lock handles this - // case. - n.removeEndpointLocked(ref) - } - } - +// addAddress adds a new address to n, so that it starts accepting packets +// targeted at the given address (and network protocol). +func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { - return nil, tcpip.ErrUnknownProtocol - } - - isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) - - // If the address is an IPv6 address and it is a permanent address, - // mark it as tentative so it goes through the DAD process if the NIC is - // enabled. If the NIC is not enabled, DAD will be started when the NIC is - // enabled. - if isIPv6Unicast && kind == permanent { - kind = permanentTentative + return tcpip.ErrUnknownProtocol } - ref := &referencedNetworkEndpoint{ - refs: 1, - addr: protocolAddress.AddressWithPrefix, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, - configType: configType, - deprecated: deprecated, + addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + if err == nil { + // We have no need for the address endpoint. + addressEndpoint.DecRef() } - - // Set up resolver if link address resolution exists for this protocol. - if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { - ref.linkCache = n.stack - ref.linkRes = linkRes - } - } - - // If we are adding an IPv6 unicast address, join the solicited-node - // multicast address. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) - if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { - return nil, err - } - } - - n.mu.endpoints[id] = ref - - n.insertPrimaryEndpointLocked(ref, peb) - - // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled. - if isIPv6Unicast && kind == permanentTentative && n.mu.enabled { - if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { - return nil, err - } - } - - return ref, nil -} - -// AddAddress adds a new address to n, so that it starts accepting packets -// targeted at the given address (and network protocol). -func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { - // Add the endpoint. - n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */) - n.mu.Unlock() - return err } -// AllAddresses returns all addresses (primary and non-primary) associated with +// allPermanentAddresses returns all permanent addresses associated with // this NIC. -func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { - n.mu.RLock() - defer n.mu.RUnlock() - - addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) - for _, ref := range n.mu.endpoints { - // Don't include tentative, expired or temporary endpoints to - // avoid confusion and prevent the caller from using those. - switch ref.getKind() { - case permanentExpired, temporary: - continue +func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { + var addrs []tcpip.ProtocolAddress + for p, ep := range n.networkEndpoints { + for _, a := range ep.PermanentAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: ref.protocol, - AddressWithPrefix: ref.addrWithPrefix(), - }) } return addrs } -// PrimaryAddresses returns the primary addresses associated with this NIC. -func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { - n.mu.RLock() - defer n.mu.RUnlock() - +// primaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress - for proto, list := range n.mu.primary { - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints - // to avoid confusion and prevent the caller from using - // those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: proto, - AddressWithPrefix: ref.addrWithPrefix(), - }) + for p, ep := range n.networkEndpoints { + for _, a := range ep.PrimaryAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } return addrs @@ -930,147 +363,25 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { // address exists. If no non-deprecated address exists, the first deprecated // address will be returned. func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { - n.mu.RLock() - defer n.mu.RUnlock() - - list, ok := n.mu.primary[proto] + ep, ok := n.networkEndpoints[proto] if !ok { return tcpip.AddressWithPrefix{} } - var deprecatedEndpoint *referencedNetworkEndpoint - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints to avoid confusion - // and prevent the caller from using those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - if !ref.deprecated { - return ref.addrWithPrefix() - } - - if deprecatedEndpoint == nil { - deprecatedEndpoint = ref - } - } - - if deprecatedEndpoint != nil { - return deprecatedEndpoint.addrWithPrefix() - } - - return tcpip.AddressWithPrefix{} -} - -// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required -// by peb. -// -// n MUST be locked. -func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { - switch peb { - case CanBePrimaryEndpoint: - n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) - case FirstPrimaryEndpoint: - n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) - } -} - -func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { - id := NetworkEndpointID{LocalAddress: r.address()} - - // Nothing to do if the reference has already been replaced with a different - // one. This happens in the case where 1) this endpoint's ref count hit zero - // and was waiting (on the lock) to be removed and 2) the same address was - // re-added in the meantime by removing this endpoint from the list and - // adding a new one. - if n.mu.endpoints[id] != r { - return - } - - if r.getKind() == permanent { - panic("Reference count dropped to zero before being removed") - } - - delete(n.mu.endpoints, id) - refs := n.mu.primary[r.protocol] - for i, ref := range refs { - if ref == r { - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - refs[len(refs)-1] = nil - break - } - } -} - -func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { - n.mu.Lock() - n.removeEndpointLocked(r) - n.mu.Unlock() -} - -func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadLocalAddress - } - - kind := r.getKind() - if kind != permanent && kind != permanentTentative { - return tcpip.ErrBadLocalAddress - } - - switch r.protocol { - case header.IPv6ProtocolNumber: - return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) - default: - r.expireLocked() - return nil - } + return ep.MainAddress() } -func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { - addr := r.addrWithPrefix() - - isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) - - if isIPv6Unicast { - n.mu.ndp.stopDuplicateAddressDetection(addr.Address) - - // If we are removing an address generated via SLAAC, cleanup - // its SLAAC resources and notify the integrator. - switch r.configType { - case slaac: - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - case slaacTemp: - n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - } - } - - r.expireLocked() - - // At this point the endpoint is deleted. - - // If we are removing an IPv6 unicast address, leave the solicited-node - // multicast address. - // - // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node - // multicast group may be left by user action. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr.Address) - if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { +// removeAddress removes an address from n. +func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { + for _, ep := range n.networkEndpoints { + if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + continue + } else { return err } } - return nil -} - -// RemoveAddress removes an address from n. -func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - return n.removePermanentAddressLocked(addr) + return tcpip.ErrBadLocalAddress } func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { @@ -1121,91 +432,66 @@ func (n *NIC) clearNeighbors() *tcpip.Error { // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - return n.joinGroupLocked(protocol, addr) -} - -// joinGroupLocked adds a new endpoint for the given multicast address, if none -// exists yet. Otherwise it just increments its count. n MUST be locked before -// joinGroupLocked is called. -func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as // outlined in RFC 3810 section 5. - id := NetworkEndpointID{addr} - joins := n.mu.mcastJoins[id] - if joins == 0 { - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return tcpip.ErrUnknownProtocol - } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported } - n.mu.mcastJoins[id] = joins + 1 - return nil + + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + _, err := gep.JoinGroup(addr) + return err } // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - return n.leaveGroupLocked(addr, false /* force */) -} +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported + } -// leaveGroupLocked decrements the count for the given multicast address, and -// when it reaches zero removes the endpoint for this address. n MUST be locked -// before leaveGroupLocked is called. -// -// If force is true, then the count for the multicast addres is ignored and the -// endpoint will be removed immediately. -func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { - id := NetworkEndpointID{addr} - joins, ok := n.mu.mcastJoins[id] + gep, ok := ep.(GroupAddressableEndpoint) if !ok { - // There are no joins with this address on this NIC. - return tcpip.ErrBadLocalAddress + return tcpip.ErrNotSupported } - joins-- - if force || joins == 0 { - // There are no outstanding joins or we are forced to leave, clean up. - delete(n.mu.mcastJoins, id) - return n.removePermanentAddressLocked(addr) + if _, err := gep.LeaveGroup(addr); err != nil { + return err } - n.mu.mcastJoins[id] = joins return nil } // isInGroup returns true if n has joined the multicast group addr. func (n *NIC) isInGroup(addr tcpip.Address) bool { - n.mu.RLock() - joins := n.mu.mcastJoins[NetworkEndpointID{addr}] - n.mu.RUnlock() + for _, ep := range n.networkEndpoints { + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + continue + } + + if gep.IsInGroup(addr) { + return true + } + } - return joins != 0 + return false } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) +func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { + r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + defer r.Release() r.RemoteLinkAddress = remotelinkAddr - - ref.ep.HandlePacket(&r, pkt) - ref.decRef() + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -1216,7 +502,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // the ownership of the items is not retained by the caller. func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() - enabled := n.mu.enabled + enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. if !enabled { n.mu.RUnlock() @@ -1274,17 +560,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return + if n.stack.handleLocal && !n.IsLoopback() { + if r := n.getAddress(protocol, src); r != nil { + r.DecRef() + + // The source address is one of our own, so we never should have gotten a + // packet like this unless handleLocal is false. Loopback also calls this + // function even though the packets didn't come from the physical interface + // so don't drop those. + n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() + return + } } // Loopback traffic skips the prerouting chain. - if !n.isLoopback() { + if !n.IsLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) @@ -1295,8 +585,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) + if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { + n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) return } @@ -1312,20 +602,20 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.ref.nic - n.mu.RLock() - ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() - n.mu.RUnlock() - if ok { - r.LocalLinkAddress = n.linkEP.LinkAddress() - r.RemoteLinkAddress = remote - r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, pkt) - ref.decRef() - r.Release() - return + n := r.nic + if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { + if n.isValidForOutgoing(addressEndpoint) { + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remote + r.RemoteAddress = src + // TODO(b/123449044): Update the source NIC as well. + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + addressEndpoint.DecRef() + r.Release() + return + } + + addressEndpoint.DecRef() } // n doesn't have a destination endpoint. @@ -1401,10 +691,8 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - // TODO(gvisor.dev/issue/4365): Let the caller know that the transport - // protocol is unrecognized. n.stack.stats.UnknownProtocolRcvdPackets.Increment() - return TransportPacketHandled + return TransportPacketProtocolUnreachable } transProto := state.proto @@ -1498,96 +786,23 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } } -// ID returns the identifier of n. +// ID implements NetworkInterface. func (n *NIC) ID() tcpip.NICID { return n.id } -// Name returns the name of n. +// Name implements NetworkInterface. func (n *NIC) Name() string { return n.name } -// Stack returns the instance of the Stack that owns this NIC. -func (n *NIC) Stack() *Stack { - return n.stack -} - -// LinkEndpoint returns the link endpoint of n. +// LinkEndpoint implements NetworkInterface. func (n *NIC) LinkEndpoint() LinkEndpoint { return n.linkEP } -// isAddrTentative returns true if addr is tentative on n. -// -// Note that if addr is not associated with n, then this function will return -// false. It will only return true if the address is associated with the NIC -// AND it is tentative. -func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - n.mu.RLock() - defer n.mu.RUnlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return false - } - - return ref.getKind() == permanentTentative -} - -// dupTentativeAddrDetected attempts to inform n that a tentative addr is a -// duplicate on a link. -// -// dupTentativeAddrDetected will remove the tentative address if it exists. If -// the address was generated via SLAAC, an attempt will be made to generate a -// new address. -func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadAddress - } - - if ref.getKind() != permanentTentative { - return tcpip.ErrInvalidEndpointState - } - - // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a - // new address will be generated for it. - if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil { - return err - } - - prefix := ref.addrWithPrefix().Subnet() - - switch ref.configType { - case slaac: - n.mu.ndp.regenerateSLAACAddr(prefix) - case slaacTemp: - // Do not reset the generation attempts counter for the prefix as the - // temporary address is being regenerated in response to a DAD conflict. - n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) - } - - return nil -} - -// setNDPConfigs sets the NDP configurations for n. -// -// Note, if c contains invalid NDP configuration values, it will be fixed to -// use default values for the erroneous values. -func (n *NIC) setNDPConfigs(c NDPConfigurations) { - c.validate() - - n.mu.Lock() - n.mu.ndp.configs = c - n.mu.Unlock() -} - -// NUDConfigs gets the NUD configurations for n. -func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) { +// nudConfigs gets the NUD configurations for n. +func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { if n.neigh == nil { return NUDConfigurations{}, tcpip.ErrNotSupported } @@ -1607,49 +822,6 @@ func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { return nil } -// handleNDPRA handles an NDP Router Advertisement message that arrived on n. -func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.handleRA(ip, ra) -} - -type networkEndpointKind int32 - -const ( - // A permanentTentative endpoint is a permanent address that is not yet - // considered to be fully bound to an interface in the traditional - // sense. That is, the address is associated with a NIC, but packets - // destined to the address MUST NOT be accepted and MUST be silently - // dropped, and the address MUST NOT be used as a source address for - // outgoing packets. For IPv6, addresses will be of this kind until - // NDP's Duplicate Address Detection has resolved, or be deleted if - // the process results in detecting a duplicate address. - permanentTentative networkEndpointKind = iota - - // A permanent endpoint is created by adding a permanent address (vs. a - // temporary one) to the NIC. Its reference count is biased by 1 to avoid - // removal when no route holds a reference to it. It is removed by explicitly - // removing the permanent address from the NIC. - permanent - - // An expired permanent endpoint is a permanent endpoint that had its address - // removed from the NIC, and it is waiting to be removed once no more routes - // hold a reference to it. This is achieved by decreasing its reference count - // by 1. If its address is re-added before the endpoint is removed, its type - // changes back to permanent and its reference count increases by 1 again. - permanentExpired - - // A temporary endpoint is created for spoofing outgoing packets, or when in - // promiscuous mode and accepting incoming packets that don't match any - // permanent endpoint. Its reference count is not biased by 1 and the - // endpoint is removed immediately when no more route holds a reference to - // it. A temporary endpoint can be promoted to permanent if its address - // is added permanently. - temporary -) - func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -1680,153 +852,12 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } } -type networkEndpointConfigType int32 - -const ( - // A statically configured endpoint is an address that was added by - // some user-specified action (adding an explicit address, joining a - // multicast group). - static networkEndpointConfigType = iota - - // A SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4862 section 5.5.3. - slaac - - // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are - // not expected to be valid (or preferred) forever; hence the term temporary. - slaacTemp -) - -type referencedNetworkEndpoint struct { - ep NetworkEndpoint - addr tcpip.AddressWithPrefix - nic *NIC - protocol tcpip.NetworkProtocolNumber - - // linkCache is set if link address resolution is enabled for this - // protocol. Set to nil otherwise. - linkCache LinkAddressCache - - // linkRes is set if link address resolution is enabled for this protocol. - // Set to nil otherwise. - linkRes LinkAddressResolver - - // refs is counting references held for this endpoint. When refs hits zero it - // triggers the automatic removal of the endpoint from the NIC. - refs int32 - - // networkEndpointKind must only be accessed using {get,set}Kind(). - kind networkEndpointKind - - // configType is the method that was used to configure this endpoint. - // This must never change except during endpoint creation and promotion to - // permanent. - configType networkEndpointConfigType - - // deprecated indicates whether or not the endpoint should be considered - // deprecated. That is, when deprecated is true, other endpoints that are not - // deprecated should be preferred. - deprecated bool -} - -func (r *referencedNetworkEndpoint) address() tcpip.Address { - return r.addr.Address -} - -func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { - return r.addr -} - -func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { - return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) -} - -func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { - atomic.StoreInt32((*int32)(&r.kind), int32(kind)) -} - // isValidForOutgoing returns true if the endpoint can be used to send out a // packet. It requires the endpoint to not be marked expired (i.e., its address) // has been removed) unless the NIC is in spoofing mode, or temporary. -func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - r.nic.mu.RLock() - defer r.nic.mu.RUnlock() - - return r.isValidForOutgoingRLocked() -} - -// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires -// r.nic.mu to be read locked. -func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - if !r.nic.mu.enabled { - return false - } - - return r.isAssignedRLocked(r.nic.mu.spoofing) -} - -// isAssignedRLocked returns true if r is considered to be assigned to the NIC. -// -// r.nic.mu must be read locked. -func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool { - switch r.getKind() { - case permanentTentative: - return false - case permanentExpired: - return spoofingOrPromiscuous - default: - return true - } -} - -// expireLocked decrements the reference count and marks the permanent endpoint -// as expired. -func (r *referencedNetworkEndpoint) expireLocked() { - r.setKind(permanentExpired) - r.decRefLocked() -} - -// decRef decrements the ref count and cleans up the endpoint once it reaches -// zero. -func (r *referencedNetworkEndpoint) decRef() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpoint(r) - } -} - -// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpointLocked(r) - } -} - -// incRef increments the ref count. It must only be called when the caller is -// known to be holding a reference to the endpoint, otherwise tryIncRef should -// be used. -func (r *referencedNetworkEndpoint) incRef() { - atomic.AddInt32(&r.refs, 1) -} - -// tryIncRef attempts to increment the ref count from n to n+1, but only if n is -// not zero. That is, it will increment the count if the endpoint is still -// alive, and do nothing if it has already been clean up. -func (r *referencedNetworkEndpoint) tryIncRef() bool { - for { - v := atomic.LoadInt32(&r.refs) - if v == 0 { - return false - } - - if atomic.CompareAndSwapInt32(&r.refs, v, v+1) { - return true - } - } -} - -// stack returns the Stack instance that owns the underlying endpoint. -func (r *referencedNetworkEndpoint) stack() *Stack { - return r.nic.stack +func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + return n.Enabled() && ep.IsAssigned(spoofing) } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index bc9c9881a..fdd49b77f 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,96 +15,40 @@ package stack import ( - "math" "testing" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) -var _ LinkEndpoint = (*testLinkEndpoint)(nil) +var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) +var _ NDPEndpoint = (*testIPv6Endpoint)(nil) -// A LinkEndpoint that throws away outgoing packets. +// An IPv6 NetworkEndpoint that throws away outgoing packets. // -// We use this instead of the channel endpoint as the channel package depends on +// We use this instead of ipv6.endpoint because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. -type testLinkEndpoint struct { - dispatcher NetworkDispatcher -} - -// Attach implements LinkEndpoint.Attach. -func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements LinkEndpoint.IsAttached. -func (e *testLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements LinkEndpoint.MTU. -func (*testLinkEndpoint) MTU() uint32 { - return math.MaxUint16 -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { - return CapabilityResolutionRequired -} +type testIPv6Endpoint struct { + AddressableEndpointState -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*testLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} + nicID tcpip.NICID + linkEP LinkEndpoint + protocol *testIPv6Protocol -// LinkAddress returns the link address of this endpoint. -func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" + invalidatedRtr tcpip.Address } -// Wait implements LinkEndpoint.Wait. -func (*testLinkEndpoint) Wait() {} - -// WritePacket implements LinkEndpoint.WritePacket. -func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) Enable() *tcpip.Error { return nil } -// WritePackets implements LinkEndpoint.WritePackets. -func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // Our tests don't use this so we don't support it. - return 0, tcpip.ErrNotSupported -} - -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // Our tests don't use this so we don't support it. - return tcpip.ErrNotSupported -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") +func (*testIPv6Endpoint) Enabled() bool { + return true } -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - panic("not implemented") -} - -var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) - -// An IPv6 NetworkEndpoint that throws away outgoing packets. -// -// We use this instead of ipv6.endpoint because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Endpoint struct { - nicID tcpip.NICID - linkEP LinkEndpoint - protocol *testIPv6Protocol -} +func (*testIPv6Endpoint) Disable() {} // DefaultTTL implements NetworkEndpoint.DefaultTTL. func (*testIPv6Endpoint) DefaultTTL() uint8 { @@ -116,11 +60,6 @@ func (e *testIPv6Endpoint) MTU() uint32 { return e.linkEP.MTU() - header.IPv6MinimumSize } -// Capabilities implements NetworkEndpoint.Capabilities. -func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize @@ -144,23 +83,24 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip return tcpip.ErrNotSupported } -// NICID implements NetworkEndpoint.NICID. -func (e *testIPv6Endpoint) NICID() tcpip.NICID { - return e.nicID -} - // HandlePacket implements NetworkEndpoint.HandlePacket. func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { } // Close implements NetworkEndpoint.Close. -func (*testIPv6Endpoint) Close() {} +func (e *testIPv6Endpoint) Close() { + e.AddressableEndpointState.Cleanup() +} // NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } +func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { + e.invalidatedRtr = rtr +} + var _ NetworkProtocol = (*testIPv6Protocol)(nil) // An IPv6 NetworkProtocol that supports the bare minimum to make a stack @@ -192,12 +132,14 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { - return &testIPv6Endpoint{ - nicID: nicID, - linkEP: linkEP, +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { + e := &testIPv6Endpoint{ + nicID: nic.ID(), + linkEP: nic.LinkEndpoint(), protocol: p, } + e.AddressableEndpointState.Init(e) + return e } // SetOption implements NetworkProtocol.SetOption. @@ -241,42 +183,6 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd return "", false } -func newTestIPv6Protocol(*Stack) NetworkProtocol { - return &testIPv6Protocol{} -} - -// Test the race condition where a NIC is removed and an RS timer fires at the -// same time. -func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { - const ( - nicID = 1 - - maxRtrSolicitations = 5 - ) - - e := testLinkEndpoint{} - s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{newTestIPv6Protocol}, - NDPConfigs: NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: minimumRtrSolicitationInterval, - }, - }) - - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) - } - - s.mu.Lock() - // Wait for the router solicitation timer to fire and block trying to obtain - // the stack lock when doing link address resolution. - time.Sleep(minimumRtrSolicitationInterval * 2) - if err := s.removeNICLocked(nicID); err != nil { - t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) - } - s.mu.Unlock() -} - func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index a7d9d59fa..105583c49 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) type headerType int @@ -255,6 +256,20 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { return newPk } +// Network returns the network header as a header.Network. +// +// Network should only be called when NetworkHeader has been set. +func (pk *PacketBuffer) Network() header.Network { + switch netProto := pk.NetworkProtocolNumber; netProto { + case header.IPv4ProtocolNumber: + return header.IPv4(pk.NetworkHeader().View()) + case header.IPv6ProtocolNumber: + return header.IPv6(pk.NetworkHeader().View()) + default: + panic(fmt.Sprintf("unknown network protocol number %d", netProto)) + } +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 780a5ebde..be9bd8042 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -152,10 +154,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -206,6 +208,10 @@ const ( // transport layer and callers need not take any further action. TransportPacketHandled TransportPacketDisposition = iota + // TransportPacketProtocolUnreachable indicates that the transport + // protocol requested in the packet is not supported. + TransportPacketProtocolUnreachable + // TransportPacketDestinationPortUnreachable indicates that there weren't any // listeners interested in the packet and the transport protocol has no means // to notify the sender. @@ -259,9 +265,253 @@ type NetworkHeaderParams struct { TOS uint8 } +// GroupAddressableEndpoint is an endpoint that supports group addressing. +// +// An endpoint is considered to support group addressing when one or more +// endpoints may associate themselves with the same identifier (group address). +type GroupAddressableEndpoint interface { + // JoinGroup joins the spcified group. + // + // Returns true if the group was newly joined. + JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + + // LeaveGroup attempts to leave the specified group. + // + // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. + LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + + // IsInGroup returns true if the endpoint is a member of the specified group. + IsInGroup(group tcpip.Address) bool +} + +// PrimaryEndpointBehavior is an enumeration of an AddressEndpoint's primary +// behavior. +type PrimaryEndpointBehavior int + +const ( + // CanBePrimaryEndpoint indicates the endpoint can be used as a primary + // endpoint for new connections with no local address. This is the + // default when calling NIC.AddAddress. + CanBePrimaryEndpoint PrimaryEndpointBehavior = iota + + // FirstPrimaryEndpoint indicates the endpoint should be the first + // primary endpoint considered. If there are multiple endpoints with + // this behavior, they are ordered by recency. + FirstPrimaryEndpoint + + // NeverPrimaryEndpoint indicates the endpoint should never be a + // primary endpoint. + NeverPrimaryEndpoint +) + +// AddressConfigType is the method used to add an address. +type AddressConfigType int + +const ( + // AddressConfigStatic is a statically configured address endpoint that was + // added by some user-specified action (adding an explicit address, joining a + // multicast group). + AddressConfigStatic AddressConfigType = iota + + // AddressConfigSlaac is an address endpoint added by SLAAC, as per RFC 4862 + // section 5.5.3. + AddressConfigSlaac + + // AddressConfigSlaacTemp is a temporary address endpoint added by SLAAC as + // per RFC 4941. Temporary SLAAC addresses are short-lived and are not + // to be valid (or preferred) forever; hence the term temporary. + AddressConfigSlaacTemp +) + +// AssignableAddressEndpoint is a reference counted address endpoint that may be +// assigned to a NetworkEndpoint. +type AssignableAddressEndpoint interface { + // AddressWithPrefix returns the endpoint's address. + AddressWithPrefix() tcpip.AddressWithPrefix + + // IsAssigned returns whether or not the endpoint is considered bound + // to its NetworkEndpoint. + IsAssigned(allowExpired bool) bool + + // IncRef increments this endpoint's reference count. + // + // Returns true if it was successfully incremented. If it returns false, then + // the endpoint is considered expired and should no longer be used. + IncRef() bool + + // DecRef decrements this endpoint's reference count. + DecRef() +} + +// AddressEndpoint is an endpoint representing an address assigned to an +// AddressableEndpoint. +type AddressEndpoint interface { + AssignableAddressEndpoint + + // GetKind returns the address kind for this endpoint. + GetKind() AddressKind + + // SetKind sets the address kind for this endpoint. + SetKind(AddressKind) + + // ConfigType returns the method used to add the address. + ConfigType() AddressConfigType + + // Deprecated returns whether or not this endpoint is deprecated. + Deprecated() bool + + // SetDeprecated sets this endpoint's deprecated status. + SetDeprecated(bool) +} + +// AddressKind is the kind of of an address. +// +// See the values of AddressKind for more details. +type AddressKind int + +const ( + // PermanentTentative is a permanent address endpoint that is not yet + // considered to be fully bound to an interface in the traditional + // sense. That is, the address is associated with a NIC, but packets + // destined to the address MUST NOT be accepted and MUST be silently + // dropped, and the address MUST NOT be used as a source address for + // outgoing packets. For IPv6, addresses are of this kind until NDP's + // Duplicate Address Detection (DAD) resolves. If DAD fails, the address + // is removed. + PermanentTentative AddressKind = iota + + // Permanent is a permanent endpoint (vs. a temporary one) assigned to the + // NIC. Its reference count is biased by 1 to avoid removal when no route + // holds a reference to it. It is removed by explicitly removing the address + // from the NIC. + Permanent + + // PermanentExpired is a permanent endpoint that had its address removed from + // the NIC, and it is waiting to be removed once no references to it are held. + // + // If the address is re-added before the endpoint is removed, its type + // changes back to Permanent. + PermanentExpired + + // Temporary is an endpoint, created on a one-off basis to temporarily + // consider the NIC bound an an address that it is not explictiy bound to + // (such as a permanent address). Its reference count must not be biased by 1 + // so that the address is removed immediately when references to it are no + // longer held. + // + // A temporary endpoint may be promoted to permanent if the address is added + // permanently. + Temporary +) + +// IsPermanent returns true if the AddressKind represents a permanent address. +func (k AddressKind) IsPermanent() bool { + switch k { + case Permanent, PermanentTentative: + return true + case Temporary, PermanentExpired: + return false + default: + panic(fmt.Sprintf("unrecognized address kind = %d", k)) + } +} + +// AddressableEndpoint is an endpoint that supports addressing. +// +// An endpoint is considered to support addressing when the endpoint may +// associate itself with an identifier (address). +type AddressableEndpoint interface { + // AddAndAcquirePermanentAddress adds the passed permanent address. + // + // Returns tcpip.ErrDuplicateAddress if the address exists. + // + // Acquires and returns the AddressEndpoint for the added address. + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) + + // RemovePermanentAddress removes the passed address if it is a permanent + // address. + // + // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed + // permanent address. + RemovePermanentAddress(addr tcpip.Address) *tcpip.Error + + // MainAddress returns the endpoint's primary permanent address. + MainAddress() tcpip.AddressWithPrefix + + // AcquireAssignedAddress returns an address endpoint for the passed address + // that is considered bound to the endpoint, optionally creating a temporary + // endpoint if requested and no existing address exists. + // + // The returned endpoint's reference count is incremented. + // + // Returns nil if the specified address is not local to this endpoint. + AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint + + // AcquireOutgoingPrimaryAddress returns a primary address that may be used as + // a source address when sending packets to the passed remote address. + // + // If allowExpired is true, expired addresses may be returned. + // + // The returned endpoint's reference count is incremented. + // + // Returns nil if a primary address is not available. + AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint + + // PrimaryAddresses returns the primary addresses. + PrimaryAddresses() []tcpip.AddressWithPrefix + + // PermanentAddresses returns all the permanent addresses. + PermanentAddresses() []tcpip.AddressWithPrefix +} + +// NDPEndpoint is a network endpoint that supports NDP. +type NDPEndpoint interface { + NetworkEndpoint + + // InvalidateDefaultRouter invalidates a default router discovered through + // NDP. + InvalidateDefaultRouter(tcpip.Address) +} + +// NetworkInterface is a network interface. +type NetworkInterface interface { + // ID returns the interface's ID. + ID() tcpip.NICID + + // IsLoopback returns true if the interface is a loopback interface. + IsLoopback() bool + + // Name returns the name of the interface. + // + // May return an empty string if the interface is not configured with a name. + Name() string + + // Enabled returns true if the interface is enabled. + Enabled() bool + + // LinkEndpoint returns the link endpoint backing the interface. + LinkEndpoint() LinkEndpoint +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { + AddressableEndpoint + + // Enable enables the endpoint. + // + // Must only be called when the stack is in a state that allows the endpoint + // to send and receive packets. + // + // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled. + Enable() *tcpip.Error + + // Enabled returns true if the endpoint is enabled. + Enabled() bool + + // Disable disables the endpoint. + Disable() + // DefaultTTL is the default time-to-live value (or hop limit, in ipv6) // for this endpoint. DefaultTTL() uint8 @@ -271,10 +521,6 @@ type NetworkEndpoint interface { // minus the network endpoint max header length. MTU() uint32 - // Capabilities returns the set of capabilities supported by the - // underlying link-layer endpoint. - Capabilities() LinkEndpointCapabilities - // MaxHeaderLength returns the maximum size the network (and lower // level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -295,9 +541,6 @@ type NetworkEndpoint interface { // header to the given destination address. It takes ownership of pkt. WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error - // NICID returns the id of the NIC this endpoint belongs to. - NICID() tcpip.NICID - // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. // @@ -312,6 +555,17 @@ type NetworkEndpoint interface { NetworkProtocolNumber() tcpip.NetworkProtocolNumber } +// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. +type ForwardingNetworkProtocol interface { + NetworkProtocol + + // Forwarding returns the forwarding configuration. + Forwarding() bool + + // SetForwarding sets the forwarding configuration. + SetForwarding(bool) +} + // NetworkProtocol is the interface that needs to be implemented by network // protocols (e.g., ipv4, ipv6) that want to be part of the networking stack. type NetworkProtocol interface { @@ -331,7 +585,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint + NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -460,8 +714,8 @@ type LinkEndpoint interface { // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // - // Attach will be called with a nil dispatcher if the receiver's associated - // NIC is being removed. + // Attach is called with a nil dispatcher when the endpoint's NIC is being + // removed. Attach(dispatcher NetworkDispatcher) // IsAttached returns whether a NetworkDispatcher is attached to the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 2cbbf0de8..cc39c9a6a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -42,17 +42,27 @@ type Route struct { // NetProto is the network-layer protocol. NetProto tcpip.NetworkProtocolNumber - // ref a reference to the network endpoint through which the route - // starts. - ref *referencedNetworkEndpoint - // Loop controls where WritePacket should send packets. Loop PacketLooping + + // nic is the NIC the route goes through. + nic *NIC + + // addressEndpoint is the local address this route is associated with. + addressEndpoint AssignableAddressEndpoint + + // linkCache is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkCache LinkAddressCache + + // linkRes is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkRes LinkAddressResolver } // makeRoute initializes a new route. It takes ownership of the provided -// reference to a network endpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route { +// AssignableAddressEndpoint. +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { loop := PacketOut if handleLocal && localAddr != "" && remoteAddr == localAddr { loop = PacketLoop @@ -62,29 +72,40 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } - return Route{ + linkEP := nic.LinkEndpoint() + r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: localLinkAddr, + LocalLinkAddress: linkEP.LinkAddress(), RemoteAddress: remoteAddr, - ref: ref, + addressEndpoint: addressEndpoint, + nic: nic, Loop: loop, } + + if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { + r.linkRes = linkRes + r.linkCache = nic.stack + } + } + + return r } // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.ref.ep.NICID() + return r.nic.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.ref.ep.MaxHeaderLength() + return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.ref.nic.stack.Stats() + return r.nic.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -95,12 +116,12 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.ref.ep.Capabilities() + return r.nic.LinkEndpoint().Capabilities() } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.ref.ep.(GSOEndpoint); ok { + if gso, ok := r.nic.getNetworkEndpoint(r.NetProto).(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -138,8 +159,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { nextAddr = r.RemoteAddress } - if r.ref.nic.neigh != nil { - entry, ch, err := r.ref.nic.neigh.entry(nextAddr, r.LocalAddress, r.ref.linkRes, waker) + if neigh := r.nic.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) if err != nil { return ch, err } @@ -147,7 +168,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { return nil, nil } - linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) if err != nil { return ch, err } @@ -162,12 +183,12 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr = r.RemoteAddress } - if r.ref.nic.neigh != nil { - r.ref.nic.neigh.removeWaker(nextAddr, waker) + if neigh := r.nic.neigh; neigh != nil { + neigh.removeWaker(nextAddr, waker) return } - r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker) + r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) } // IsResolutionRequired returns true if Resolve() must be called to resolve @@ -175,104 +196,88 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // // The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { - if r.ref.nic.neigh != nil { - return r.ref.isValidForOutgoing() && r.ref.linkRes != nil && r.RemoteLinkAddress == "" + if r.nic.neigh != nil { + return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" } - return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" + return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return tcpip.ErrInvalidEndpointState } // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() - err := r.ref.ep.WritePacket(r, gso, params, pkt) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil { + return err } - return err + + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return 0, tcpip.ErrInvalidEndpointState } - // WritePackets takes ownership of pkt, calculate length first. - numPkts := pkts.Len() - - n, err := r.ref.ep.WritePackets(r, gso, pkts, params) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) - } - r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - + n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return tcpip.ErrInvalidEndpointState } // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Data.Size() - if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil { return err } - r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.ref.ep.DefaultTTL() + return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.ref.ep.MTU() -} - -// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying -// network endpoint. -func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return r.ref.ep.NetworkProtocolNumber() + return r.nic.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.ref != nil { - r.ref.decRef() - r.ref = nil + if r.addressEndpoint != nil { + r.addressEndpoint.DecRef() + r.addressEndpoint = nil } } -// Clone Clone a route such that the original one can be released and the new -// one will remain valid. +// Clone clones the route. func (r *Route) Clone() Route { - if r.ref != nil { - r.ref.incRef() + if r.addressEndpoint != nil { + _ = r.addressEndpoint.IncRef() } return *r } @@ -296,7 +301,7 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.ref.stack() + return r.nic.stack } func (r *Route) isV4Broadcast(addr tcpip.Address) bool { @@ -304,7 +309,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.ref.addrWithPrefix().Subnet() + subnet := r.addressEndpoint.AddressWithPrefix().Subnet() return subnet.IsBroadcast(addr) } @@ -330,7 +335,10 @@ func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { LocalLinkAddress: r.RemoteLinkAddress, RemoteAddress: src, RemoteLinkAddress: r.LocalLinkAddress, - ref: r.ref, Loop: r.Loop, + addressEndpoint: r.addressEndpoint, + nic: r.nic, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c22633f6b..0bf20c0e1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -363,38 +363,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { return atomic.AddUint64((*uint64)(u), 1) } -// NICNameFromID is a function that returns a stable name for the specified NIC, -// even if different NIC IDs are used to refer to the same NIC in different -// program runs. It is used when generating opaque interface identifiers (IIDs). -// If the NIC was created with a name, it will be passed to NICNameFromID. -// -// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are -// generated for the same prefix on differnt NICs. -type NICNameFromID func(tcpip.NICID, string) string - -// OpaqueInterfaceIdentifierOptions holds the options related to the generation -// of opaque interface indentifiers (IIDs) as defined by RFC 7217. -type OpaqueInterfaceIdentifierOptions struct { - // NICNameFromID is a function that returns a stable name for a specified NIC, - // even if the NIC ID changes over time. - // - // Must be specified to generate the opaque IID. - NICNameFromID NICNameFromID - - // SecretKey is a pseudo-random number used as the secret key when generating - // opaque IIDs as defined by RFC 7217. The key SHOULD be at least - // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness - // requirements for security as outlined by RFC 4086. SecretKey MUST NOT - // change between program runs, unless explicitly changed. - // - // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey - // MUST NOT be modified after Stack is created. - // - // May be nil, but a nil value is highly discouraged to maintain - // some level of randomness between nodes. - SecretKey []byte -} - // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -402,13 +370,6 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver - // forwarding contains the whether packet forwarding is enabled or not for - // different network protocols. - forwarding struct { - sync.RWMutex - protocols map[tcpip.NetworkProtocolNumber]bool - } - // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. rawFactory RawFactory @@ -461,9 +422,6 @@ type Stack struct { // TODO(gvisor.dev/issue/940): S/R this field. seed uint32 - // ndpConfigs is the default NDP configurations used by interfaces. - ndpConfigs NDPConfigurations - // nudConfigs is the default NUD configurations used by interfaces. nudConfigs NUDConfigurations @@ -471,15 +429,6 @@ type Stack struct { // by the NIC's neighborCache instead of linkAddrCache. useNeighborCache bool - // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. - autoGenIPv6LinkLocal bool - - // ndpDisp is the NDP event dispatcher that is used to send the netstack - // integrator NDP related events. - ndpDisp NDPDispatcher - // nudDisp is the NUD event dispatcher that is used to send the netstack // integrator NUD related events. nudDisp NUDDispatcher @@ -487,14 +436,6 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // opaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. - opaqueIIDOpts OpaqueInterfaceIdentifierOptions - - // tempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - tempIIDSeed []byte - // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue @@ -553,13 +494,6 @@ type Options struct { // UniqueID is an optional generator of unique identifiers. UniqueID UniqueID - // NDPConfigs is the default NDP configurations used by interfaces. - // - // By default, NDPConfigs will have a zero value for its - // DupAddrDetectTransmits field, implying that DAD will not be performed - // before assigning an address to a NIC. - NDPConfigs NDPConfigurations - // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations @@ -570,24 +504,6 @@ type Options struct { // and ClearNeighbors. UseNeighborCache bool - // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to - // auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. - // - // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address Detection - // is enabled, an address will only be assigned if it successfully resolves. - // If it fails, no further attempt will be made to auto-generate an IPv6 - // link-local address. - // - // The generated link-local address will follow RFC 4291 Appendix A - // guidelines. - AutoGenIPv6LinkLocal bool - - // NDPDisp is the NDP event dispatcher that an integrator can provide to - // receive NDP related events. - NDPDisp NDPDispatcher - // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. NUDDisp NUDDispatcher @@ -596,31 +512,12 @@ type Options struct { // this is non-nil. RawFactory RawFactory - // OpaqueIIDOpts hold the options for generating opaque interface - // identifiers (IIDs) as outlined by RFC 7217. - OpaqueIIDOpts OpaqueInterfaceIdentifierOptions - // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data // returned by rand.Read(). // // RandSource must be thread-safe. RandSource mathrand.Source - - // TempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - // - // Temporary SLAAC adresses are short-lived addresses which are unpredictable - // and random from the perspective of other nodes on the network. It is - // recommended that the seed be a random byte buffer of at least - // header.IIDSize bytes to make sure that temporary SLAAC addresses are - // sufficiently random. It should follow minimum randomness requirements for - // security as outlined by RFC 4086. - // - // Note: using a nil value, the same seed across netstack program runs, or a - // seed that is too small would reduce randomness and increase predictability, - // defeating the purpose of temporary SLAAC addresses. - TempIIDSeed []byte } // TransportEndpointInfo holds useful information about a transport endpoint @@ -723,36 +620,28 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } - // Make sure opts.NDPConfigs contains valid values only. - opts.NDPConfigs.validate() - opts.NUDConfigs.resetInvalidFields() s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), - nics: make(map[tcpip.NICID]*NIC), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: DefaultTables(), - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - ndpConfigs: opts.NDPConfigs, - nudConfigs: opts.NUDConfigs, - useNeighborCache: opts.UseNeighborCache, - autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, - uniqueIDGenerator: opts.UniqueID, - ndpDisp: opts.NDPDisp, - nudDisp: opts.NUDDisp, - opaqueIIDOpts: opts.OpaqueIIDOpts, - tempIIDSeed: opts.TempIIDSeed, - forwarder: newForwardQueue(), - randomGenerator: mathrand.New(randSrc), + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: DefaultTables(), + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + nudConfigs: opts.NUDConfigs, + useNeighborCache: opts.UseNeighborCache, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + forwarder: newForwardQueue(), + randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -764,7 +653,6 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } - s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool) // Add specified network protocols. for _, netProtoFactory := range opts.NetworkProtocols { @@ -884,42 +772,37 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables packet forwarding between NICs. -func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) { - s.forwarding.Lock() - defer s.forwarding.Unlock() - - // If this stack does not support the protocol, do nothing. - if _, ok := s.networkProtocols[protocol]; !ok { - return +// SetForwarding enables or disables packet forwarding between NICs for the +// passed protocol. +func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error { + protocol, ok := s.networkProtocols[protocolNum] + if !ok { + return tcpip.ErrUnknownProtocol } - // If the forwarding value for this protocol hasn't changed then do - // nothing. - if forwarding := s.forwarding.protocols[protocol]; forwarding == enable { - return + forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + if !ok { + return tcpip.ErrNotSupported } - s.forwarding.protocols[protocol] = enable + forwardingProtocol.SetForwarding(enable) + return nil +} - if protocol == header.IPv6ProtocolNumber { - if enable { - for _, nic := range s.nics { - nic.becomeIPv6Router() - } - } else { - for _, nic := range s.nics { - nic.becomeIPv6Host() - } - } +// Forwarding returns true if packet forwarding between NICs is enabled for the +// passed protocol. +func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { + protocol, ok := s.networkProtocols[protocolNum] + if !ok { + return false } -} -// Forwarding returns if packet forwarding between NICs is enabled. -func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - s.forwarding.RLock() - defer s.forwarding.RUnlock() - return s.forwarding.protocols[protocol] + forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + if !ok { + return false + } + + return forwardingProtocol.Forwarding() } // SetRouteTable assigns the route table to be used by this stack. It @@ -954,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewEndpoint(s, network, waiterQueue) + return t.proto.NewEndpoint(network, waiterQueue) } // NewRawEndpoint creates a new raw transport layer endpoint of the given @@ -974,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewRawEndpoint(s, network, waiterQueue) + return t.proto.NewRawEndpoint(network, waiterQueue) } // NewPacketEndpoint creates a new packet endpoint listening for the given @@ -1081,7 +964,8 @@ func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - return nic.disable() + nic.disable() + return nil } // CheckNIC checks if a NIC is usable. @@ -1094,7 +978,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } - return nic.enabled() + return nic.Enabled() } // RemoveNIC removes NIC and all related routes from the network stack. @@ -1172,14 +1056,14 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.enabled(), + Running: nic.Enabled(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.isLoopback(), + Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.PrimaryAddresses(), + ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -1243,7 +1127,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return tcpip.ErrUnknownNICID } - return nic.AddAddress(protocolAddress, peb) + return nic.addAddress(protocolAddress, peb) } // RemoveAddress removes an existing network-layer address from the specified @@ -1253,7 +1137,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - return nic.RemoveAddress(addr) + return nic.removeAddress(addr) } return tcpip.ErrUnknownNICID @@ -1267,7 +1151,7 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) for id, nic := range s.nics { - nics[id] = nic.AllAddresses() + nics[id] = nic.allPermanentAddresses() } return nics } @@ -1289,7 +1173,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), nil } -func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { +func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { if len(localAddr) == 0 { return nic.primaryEndpoint(netProto, remoteAddr) } @@ -1306,9 +1190,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok && nic.enabled() { - if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { - return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil + if nic, ok := s.nics[id]; ok && nic.Enabled() { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil } } } else { @@ -1316,20 +1200,20 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { - if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { + if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. - remoteAddr = ref.address() + remoteAddr = addressEndpoint.AddressWithPrefix().Address } - r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) + r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) if len(route.Gateway) > 0 { if needRoute { r.NextHop = route.Gateway } - } else if subnet := ref.addrWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { + } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { r.RemoteLinkAddress = header.EthernetBroadcastAddress } @@ -1367,21 +1251,20 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto return 0 } - ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if ref == nil { + addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) + if addressEndpoint == nil { return 0 } - ref.decRef() + addressEndpoint.DecRef() return nic.id } // Go through all the NICs. for _, nic := range s.nics { - ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if ref != nil { - ref.decRef() + if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() return nic.id } } @@ -1850,7 +1733,7 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { - return nic.leaveGroup(multicastAddr) + return nic.leaveGroup(protocol, multicastAddr) } return tcpip.ErrUnknownNICID } @@ -1902,53 +1785,18 @@ func (s *Stack) AllowICMPMessage() bool { return s.icmpRateLimiter.Allow() } -// IsAddrTentative returns true if addr is tentative on the NIC with ID id. -// -// Note that if addr is not associated with a NIC with id ID, then this -// function will return false. It will only return true if the address is -// associated with the NIC AND it is tentative. -func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) { - s.mu.RLock() - defer s.mu.RUnlock() - - nic, ok := s.nics[id] - if !ok { - return false, tcpip.ErrUnknownNICID - } - - return nic.isAddrTentative(addr), nil -} - -// DupTentativeAddrDetected attempts to inform the NIC with ID id that a -// tentative addr on it is a duplicate on a link. -func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() - - nic, ok := s.nics[id] - if !ok { - return tcpip.ErrUnknownNICID - } - - return nic.dupTentativeAddrDetected(addr) -} - -// SetNDPConfigurations sets the per-interface NDP configurations on the NIC -// with ID id to c. -// -// Note, if c contains invalid NDP configuration values, it will be fixed to -// use default values for the erroneous values. -func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { +// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol +// number installed on the specified NIC. +func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() - nic, ok := s.nics[id] + nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return nil, tcpip.ErrUnknownNICID } - nic.setNDPConfigs(c) - return nil + return nic.getNetworkEndpoint(proto), nil } // NUDConfigurations gets the per-interface NUD configurations. @@ -1961,7 +1809,7 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err return NUDConfigurations{}, tcpip.ErrUnknownNICID } - return nic.NUDConfigs() + return nic.nudConfigs() } // SetNUDConfigurations sets the per-interface NUD configurations. @@ -1980,22 +1828,6 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip return nic.setNUDConfigs(c) } -// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement -// message that it needs to handle. -func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() - - nic, ok := s.nics[id] - if !ok { - return tcpip.ErrUnknownNICID - } - - nic.handleNDPRA(ip, ra) - - return nil -} - // Seed returns a 32 bit value that can be used as a seed value for port // picking, ISN generation etc. // @@ -2037,16 +1869,12 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres defer s.mu.RUnlock() for _, nic := range s.nics { - id := NetworkEndpointID{address} - - if ref, ok := nic.mu.endpoints[id]; ok { - nic.mu.RLock() - defer nic.mu.RUnlock() - - // An endpoint with this id exists, check if it can be - // used and return it. - return ref.ep, nil + addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + continue } + addressEndpoint.DecRef() + return nic.getNetworkEndpoint(netProto), nil } return nil, tcpip.ErrBadAddress } @@ -2063,3 +1891,8 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { return nic.Name() } + +// NewJob returns a new tcpip.Job using the stack's clock. +func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c205650fe..aa20f750b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -29,6 +29,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -68,18 +69,41 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { + stack.AddressableEndpointState + + mu struct { + sync.RWMutex + + enabled bool + } + nicID tcpip.NICID proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher ep stack.LinkEndpoint } -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) +func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = true + return nil } -func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID +func (f *fakeNetworkEndpoint) Enabled() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.enabled +} + +func (f *fakeNetworkEndpoint) Disable() { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = false +} + +func (f *fakeNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (*fakeNetworkEndpoint) DefaultTTL() uint8 { @@ -118,10 +142,6 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto return 0 } -func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -156,7 +176,9 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack return tcpip.ErrNotSupported } -func (*fakeNetworkEndpoint) Close() {} +func (f *fakeNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() +} // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the // number of packets sent and received via endpoints of this protocol. The index @@ -165,6 +187,11 @@ type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int defaultTTL uint8 + + mu struct { + sync.RWMutex + forwarding bool + } } func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -187,13 +214,15 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { - return &fakeNetworkEndpoint{ - nicID: nicID, +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &fakeNetworkEndpoint{ + nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, + ep: nic.LinkEndpoint(), } + e.AddressableEndpointState.Init(e) + return e } func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { @@ -231,6 +260,20 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fakeNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fakeNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -2213,7 +2256,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName string autoGen bool linkAddr tcpip.LinkAddress - iidOpts stack.OpaqueInterfaceIdentifierOptions + iidOpts ipv6.OpaqueInterfaceIdentifierOptions shouldGen bool expectedAddr tcpip.Address }{ @@ -2229,7 +2272,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "nic1", autoGen: false, linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:], }, @@ -2274,7 +2317,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "nic1", autoGen: true, linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:], }, @@ -2286,7 +2329,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { { name: "OIID Empty MAC and empty nicName", autoGen: true, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:1], }, @@ -2298,7 +2341,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test", autoGen: true, linkAddr: "\x01\x02\x03", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:2], }, @@ -2310,7 +2353,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test2", autoGen: true, linkAddr: "\x01\x02\x03\x04\x05\x06", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:3], }, @@ -2322,7 +2365,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test3", autoGen: true, linkAddr: "\x00\x00\x00\x00\x00\x00", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, }, shouldGen: true, @@ -2336,10 +2379,11 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, + })}, } e := channel.New(0, 1280, test.linkAddr) @@ -2411,15 +2455,15 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { tests := []struct { name string - opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions }{ { name: "IID From MAC", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{}, }, { name: "Opaque IID", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName }, @@ -2430,9 +2474,10 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + })}, } e := loopback.New() @@ -2461,12 +2506,13 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - ndpConfigs := stack.DefaultNDPConfigurations() + ndpConfigs := ipv6.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + })}, } e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2813,14 +2859,15 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDispatcher{}, + })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDispatcher{}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -3059,12 +3106,13 @@ func TestDoDADWhenNICEnabled(t *testing.T) { dadC: make(chan ndpDADEvent), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPDisp: &ndpDisp, + })}, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3499,3 +3547,131 @@ func TestResolveWith(t *testing.T) { t.Fatal("got r.IsResolutionRequired() = true, want = false") } } + +// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its +// associated address is removed should not cause a panic. +func TestRouteReleaseAfterAddrRemoval(t *testing.T) { + const ( + nicID = 1 + localAddr = tcpip.Address("\x01") + remoteAddr = tcpip.Address("\x02") + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err) + } + // Should not panic. + defer r.Release() + + // Check that removing the same address fails. + if err := s.RemoveAddress(nicID, localAddr); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err) + } +} + +func TestGetNetworkEndpoint(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + }, + } + + factories := make([]stack.NetworkProtocolFactory, 0, len(tests)) + for _, test := range tests { + factories = append(factories, test.protoFactory) + } + + s := stack.New(stack.Options{ + NetworkProtocols: factories, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep, err := s.GetNetworkEndpoint(nicID, test.protoNum) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err) + } + + if got := ep.NetworkProtocolNumber(); got != test.protoNum { + t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum) + } + }) + } +} + +func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + } + + // Check that we get the right initial address and prefix length. + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } + + // Should still get the address when the NIC is diabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 0774b5382..35e5b1a2e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -155,7 +155,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.nic.ID()] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -544,9 +544,11 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return true } - // If the packet is a TCP packet with a non-unicast source or destination - // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) { + // If the packet is a TCP packet with a unspecified source or non-unicast + // destination address, then do nothing further and instruct the caller to do + // the same. The network layer handles address validation for specified source + // addresses. + if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -626,7 +628,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.nic.ID()] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -681,10 +683,6 @@ func isInboundMulticastOrBroadcast(r *Route) bool { return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isInboundUnicast(r *Route) bool { - return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r) -} - -func isOutboundUnicast(r *Route) bool { - return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress) +func isSpecified(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8aae60740..62ab6d92f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -39,7 +39,7 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo - stack *stack.Stack + proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route @@ -59,8 +59,8 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Abort() { @@ -143,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) + r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return tcpip.ErrNoRoute } @@ -151,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -190,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { - if err := f.stack.RegisterTransportEndpoint( + if err := f.proto.stack.RegisterTransportEndpoint( a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, @@ -218,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - stack: f.stack, TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -262,6 +261,8 @@ type fakeTransportProtocolOptions struct { // fakeTransportProtocol is a transport-layer protocol descriptor. It // aggregates the number of packets received via endpoints of this protocol. type fakeTransportProtocol struct { + stack *stack.Stack + packetCount int controlCount int opts fakeTransportProtocolOptions @@ -271,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil } -func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -326,8 +327,8 @@ func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { return ok } -func fakeTransFactory(*stack.Stack) stack.TransportProtocol { - return &fakeTransportProtocol{} +func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { + return &fakeTransportProtocol{stack: s} } func TestTransportReceive(t *testing.T) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 464608dee..c42bb0991 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -237,6 +237,14 @@ type Timer interface { // network node. Or, in the case of unix endpoints, it may represent a path. type Address string +// WithPrefix returns the address with a prefix that represents a point subnet. +func (a Address) WithPrefix() AddressWithPrefix { + return AddressWithPrefix{ + Address: a, + PrefixLen: len(a) * 8, + } +} + // AddressMask is a bitmask for an address. type AddressMask string @@ -1614,9 +1622,6 @@ type UDPStats struct { // ChecksumErrors is the number of datagrams dropped due to bad checksums. ChecksumErrors *StatCounter - - // InvalidSourceAddress is the number of invalid sourced datagrams dropped. - InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index f35dcc084..e8caf09ba 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -16,6 +16,7 @@ package integration_test import ( "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -29,6 +30,69 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) + +type ndpDispatcher struct{} + +func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +} + +func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { + return false +} + +func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {} + +func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { + return false +} + +func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} + +func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { + return true +} + +func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} + +func (*ndpDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {} + +func (*ndpDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {} + +func (*ndpDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {} + +func (*ndpDispatcher) OnDHCPv6Configuration(tcpip.NICID, ipv6.DHCPv6ConfigurationFromNDPRA) {} + +// TestInitialLoopbackAddresses tests that the loopback interface does not +// auto-generate a link-local address when it is brought up. +func TestInitialLoopbackAddresses(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDispatcher{}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID, nicName string) string { + t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) + return "" + }, + }, + })}, + }) + + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + nicsInfo := s.NICInfo() + if nicInfo, ok := nicsInfo[nicID]; !ok { + t.Fatalf("did not find NIC with ID = %d in s.NICInfo() = %#v", nicID, nicsInfo) + } else if got := len(nicInfo.ProtocolAddresses); got != 0 { + t.Fatalf("got len(nicInfo.ProtocolAddresses) = %d, want = 0; nicInfo.ProtocolAddresses = %#v", got, nicInfo.ProtocolAddresses) + } +} + // TestLoopbackAcceptAllInSubnet tests that a loopback interface considers // itself bound to all addresses in the subnet of an assigned address. func TestLoopbackAcceptAllInSubnet(t *testing.T) { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 72d86b5ab..4f2ca7f54 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -203,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View()) + src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 7484f4ad9..87d510f96 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -37,6 +37,8 @@ const ( // protocol implements stack.TransportProtocol. type protocol struct { + stack *stack.Stack + number tcpip.TransportProtocolNumber } @@ -57,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return newEndpoint(stack, netProto, p.number, waiterQueue) + return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { if netProto != p.netProto() { return nil, tcpip.ErrUnknownProtocol } - return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) + return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } // MinimumPacketSize returns the minimum valid icmp packet size. @@ -130,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol4 returns an ICMPv4 transport protocol. -func NewProtocol4(*stack.Stack) stack.TransportProtocol { - return &protocol{ProtocolNumber4} +func NewProtocol4(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber4} } // NewProtocol6 returns an ICMPv6 transport protocol. -func NewProtocol6(*stack.Stack) stack.TransportProtocol { - return &protocol{ProtocolNumber6} +func NewProtocol6(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s, number: ProtocolNumber6} } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6891fd245..189c01c8f 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -804,7 +804,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() + pkt.NetworkProtocolNumber = r.NetProto data.ReadToVV(&pkt.Data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 7ad894840..ae817091a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2099,7 +2099,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID) + addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) e.UnlockUser() if err != nil { return err diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 6a3c2c32b..5bce73605 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -133,6 +133,8 @@ func (s *synRcvdCounter) Threshold() uint64 { } type protocol struct { + stack *stack.Stack + mu sync.RWMutex sackEnabled bool recovery tcpip.TCPRecovery @@ -159,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid tcp packet size. @@ -505,8 +507,9 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a TCP transport protocol. -func NewProtocol(*stack.Stack) stack.TransportProtocol { +func NewProtocol(s *stack.Stack) stack.TransportProtocol { p := protocol{ + stack: s, sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index d969ca23a..439932595 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -39,6 +39,9 @@ type rackControl struct { // sequence. fack seqnum.Value + // minRTT is the estimated minimum RTT of the connection. + minRTT time.Duration + // rtt is the RTT of the most recently delivered packet on the // connection (either cumulatively acknowledged or selectively // acknowledged) that was not marked invalid as a possible spurious @@ -48,7 +51,7 @@ type rackControl struct { // Update will update the RACK related fields when an ACK has been received. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 -func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) { +func (rc *rackControl) Update(seg *segment, ackSeg *segment, offset uint32) { rtt := time.Now().Sub(seg.xmitTime) // If the ACK is for a retransmitted packet, do not update if it is a @@ -65,12 +68,21 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, return } } - if rtt < srtt { + if rtt < rc.minRTT { return } } rc.rtt = rtt + + // The sender can either track a simple global minimum of all RTT + // measurements from the connection, or a windowed min-filtered value + // of recent RTT measurements. This implementation keeps track of the + // simple global minimum of all RTTs for the connection. + if rtt < rc.minRTT || rc.minRTT == 0 { + rc.minRTT = rtt + } + // Update rc.xmitTime and rc.endSequence to the transmit time and // ending sequence number of the packet which has been acknowledged // most recently. diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 4aafb4d22..48bf196d8 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -88,6 +88,13 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { avail := wndFromSpace(r.ep.receiveBufferAvailable()) + if avail == 0 { + // We have no space available to accept any data, move to zero window + // state. + r.rcvWnd = 0 + return r.rcvNxt, 0 + } + acc := r.rcvNxt.Add(seqnum.Size(avail)) newWnd := r.rcvNxt.Size(acc) curWnd := r.rcvNxt.Size(r.rcvAcc) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index c55589c45..4c9a86cda 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1365,9 +1365,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ackLeft := acked originalOutstanding := s.outstanding - s.rtt.Lock() - srtt := s.rtt.srtt - s.rtt.Unlock() for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1389,7 +1386,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Update the RACK fields if SACK is enabled. if s.ep.sackPermitted { - s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset) + s.rc.Update(seg, rcvdSeg, s.ep.tsOffset) } s.writeList.Remove(seg) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index dd810f594..5b504d0d1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5435,8 +5435,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") tests := []struct { name string @@ -6182,7 +6182,9 @@ func TestReceiveBufferAutoTuning(t *testing.T) { rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) tsVal := uint32(rawEP.TSVal) - rawEP.SendPacketWithTS([]byte{1}, tsVal) + rawEP.NextSeqNum-- + rawEP.SendPacketWithTS(nil, tsVal) + rawEP.NextSeqNum++ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 086d0bdbc..d57ed5d79 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1397,15 +1397,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - // Never receive from a multicast address. - if header.IsV4MulticastAddress(id.RemoteAddress) || - header.IsV6MulticastAddress(id.RemoteAddress) { - e.stack.Stats().UDP.InvalidSourceAddress.Increment() - e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - return - } - if !verifyChecksum(r, hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index e6fc23258..da5b1deb2 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -45,6 +45,7 @@ const ( ) type protocol struct { + stack *stack.Stack } // Number returns the udp protocol number. @@ -53,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid udp packet size. @@ -114,6 +115,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a UDP transport protocol. -func NewProtocol(*stack.Stack) stack.TransportProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0556ef879..cddedb686 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -928,42 +928,6 @@ func TestReadFromMulticast(t *testing.T) { } } -// TestReadFromMulticaststats checks that a discarded packet -// that that was sent with multicast SOURCE address increments -// the correct counters and that a regular packet does not. -func TestReadFromMulticastStats(t *testing.T) { - t.Helper() - for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - c.injectPacket(flow, payload, false) - - var want uint64 = 0 - if flow.isReverseMulticast() { - want = 1 - } - if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { - t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) - } - if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { - t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) - } - }) - } -} - // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -1466,6 +1430,30 @@ func TestNoChecksum(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct{} + +func (*testInterface) ID() tcpip.NICID { + return 0 +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1483,16 +1471,19 @@ func TestTTL(t *testing.T) { if flow.isMulticast() { wantTTL = multicastTTL } else { - var p stack.NetworkProtocol + var p stack.NetworkProtocolFactory + var n tcpip.NetworkProtocolNumber if flow.isV4() { - p = ipv4.NewProtocol(nil) + p = ipv4.NewProtocol + n = ipv4.ProtocolNumber } else { - p = ipv6.NewProtocol(nil) + p = ipv6.NewProtocol + n = ipv6.ProtocolNumber } - ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - })) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{p}, + }) + ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil) wantTTL = ep.DefaultTTL() ep.Close() } @@ -1789,16 +1780,26 @@ func TestV4UnknownDestination(t *testing.T) { checker.ICMPv4Type(header.ICMPv4DstUnreachable), checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + // We need to compare the included data part of the UDP packet that is in + // the ICMP packet with the matching original data. icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) + incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize wantLen := len(payload) if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + // To work out the data size we need to simulate what the sender would + // have done. The wanted size is the total available minus the sum of + // the headers in the UDP AND ICMP packets, given that we know the test + // had only a minimal IP header but the ICMP sender will have allowed + // for a maximally sized packet header. + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength + } - // In case of large payloads the IP packet may be truncated. Update + // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + // Add back the two headers within the payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) origDgram := header.UDP(payloadIPHeader.Payload()) if got, want := len(origDgram.Payload()), wantLen; got != want { @@ -2024,7 +2025,8 @@ func TestPayloadModifiedV4(t *testing.T) { payload := newPayload() h := unicastV4.header4Tuple(incoming) buf := c.buildV4Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be incorrect. + // Modify the payload so that the checksum value in the UDP header will be + // incorrect. buf[len(buf)-1]++ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), @@ -2054,7 +2056,8 @@ func TestPayloadModifiedV6(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be incorrect. + // Modify the payload so that the checksum value in the UDP header will be + // incorrect. buf[len(buf)-1]++ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index 06fb823f6..49ab87c58 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -270,7 +270,7 @@ func RandomID(prefix string) string { // same name, sometimes between test runs the socket does not get cleaned up // quickly enough, causing container creation to fail. func RandomContainerID() string { - return RandomID("test-container-") + return RandomID("test-container") } // Copy copies file from src to dst. diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 27279b409..9b1e7a085 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -21,7 +21,6 @@ import ( "io" "strconv" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/gohacks" "gvisor.dev/gvisor/pkg/safemem" @@ -184,51 +183,6 @@ func (rw *IOReadWriter) Write(src []byte) (int, error) { return n, err } -// CopyObjectOut copies a fixed-size value or slice of fixed-size values from -// src to the memory mapped at addr in uio. It returns the number of bytes -// copied. -// -// CopyObjectOut must use reflection to encode src; performance-sensitive -// clients should do encoding manually and use uio.CopyOut directly. -// -// Preconditions: Same as IO.CopyOut. -func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) { - w := &IOReadWriter{ - Ctx: ctx, - IO: uio, - Addr: addr, - Opts: opts, - } - // Allocate a byte slice the size of the object being marshaled. This - // adds an extra reflection call, but avoids needing to grow the slice - // during encoding, which can result in many heap-allocated slices. - b := make([]byte, 0, binary.Size(src)) - return w.Write(binary.Marshal(b, ByteOrder, src)) -} - -// CopyObjectIn copies a fixed-size value or slice of fixed-size values from -// the memory mapped at addr in uio to dst. It returns the number of bytes -// copied. -// -// CopyObjectIn must use reflection to decode dst; performance-sensitive -// clients should use uio.CopyIn directly and do decoding manually. -// -// Preconditions: Same as IO.CopyIn. -func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) { - r := &IOReadWriter{ - Ctx: ctx, - IO: uio, - Addr: addr, - Opts: opts, - } - buf := make([]byte, binary.Size(dst)) - if _, err := io.ReadFull(r, buf); err != nil { - return 0, err - } - binary.Unmarshal(buf, ByteOrder, dst) - return int(r.Addr - addr), nil -} - // CopyStringIn tuning parameters, defined outside that function for tests. const ( copyStringIncrement = 64 diff --git a/pkg/usermem/usermem_test.go b/pkg/usermem/usermem_test.go index bf3c5df2b..da60b0cc7 100644 --- a/pkg/usermem/usermem_test.go +++ b/pkg/usermem/usermem_test.go @@ -16,7 +16,6 @@ package usermem import ( "bytes" - "encoding/binary" "fmt" "reflect" "strings" @@ -174,23 +173,6 @@ type testStruct struct { Uint64 uint64 } -func TestCopyObject(t *testing.T) { - wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8} - wantN := binary.Size(wantObj) - b := &BytesIO{make([]byte, wantN)} - ctx := newContext() - if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil { - t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - var gotObj testStruct - if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil { - t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) - } - if gotObj != wantObj { - t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj) - } -} - func TestCopyStringInShort(t *testing.T) { // Tests for string length <= copyStringIncrement. want := strings.Repeat("A", copyStringIncrement-2) diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 6ac19668f..a7c4ebb0c 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -162,6 +162,12 @@ var allowedSyscalls = seccomp.SyscallRules{ }, syscall.SYS_LSEEK: {}, syscall.SYS_MADVISE: {}, + unix.SYS_MEMBARRIER: []seccomp.Rule{ + { + seccomp.EqualTo(linux.MEMBARRIER_CMD_GLOBAL), + seccomp.EqualTo(0), + }, + }, syscall.SYS_MINCORE: {}, // Used by the Go runtime as a temporarily workaround for a Linux // 5.2-5.4 bug. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 2e652ddad..dee2c4fbb 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -282,6 +282,7 @@ func New(args Args) (*Loader, error) { args.NumCPU = runtime.NumCPU() } log.Infof("CPUs: %d", args.NumCPU) + runtime.GOMAXPROCS(args.NumCPU) if args.TotalMem > 0 { // Adjust the total memory returned by the Sentry so that applications that @@ -797,7 +798,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn } // startGoferMonitor runs a goroutine to monitor gofer's health. It polls on -// the gofer FDs looking for disconnects, and destroys the container if a +// the gofer FDs looking for disconnects, and kills the container processes if a // disconnect occurs in any of the gofer FDs. func (l *Loader) startGoferMonitor(cid string, goferFDs []*fd.FD) { go func() { @@ -818,18 +819,15 @@ func (l *Loader) startGoferMonitor(cid string, goferFDs []*fd.FD) { panic(fmt.Sprintf("Error monitoring gofer FDs: %v", err)) } - // Check if the gofer has stopped as part of normal container destruction. - // This is done just to avoid sending an annoying error message to the log. - // Note that there is a small race window in between mu.Unlock() and the - // lock being reacquired in destroyContainer(), but it's harmless to call - // destroyContainer() multiple times. l.mu.Lock() - _, ok := l.processes[execID{cid: cid}] - l.mu.Unlock() - if ok { - log.Infof("Gofer socket disconnected, destroying container %q", cid) - if err := l.destroyContainer(cid); err != nil { - log.Warningf("Error destroying container %q after gofer stopped: %v", cid, err) + defer l.mu.Unlock() + + // The gofer could have been stopped due to a normal container shutdown. + // Check if the container has not stopped yet. + if tg, _ := l.tryThreadGroupFromIDLocked(execID{cid: cid}); tg != nil { + log.Infof("Gofer socket disconnected, killing container %q", cid) + if err := l.signalAllProcesses(cid, int32(linux.SIGKILL)); err != nil { + log.Warningf("Error killing container %q after gofer stopped: %v", cid, err) } } }() @@ -898,17 +896,24 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) { return 0, fmt.Errorf("container %q not started", args.ContainerID) } - // Get the container MountNamespace from the Task. + // Get the container MountNamespace from the Task. Try to acquire ref may fail + // in case it raced with task exit. if kernel.VFS2Enabled { // task.MountNamespace() does not take a ref, so we must do so ourselves. args.MountNamespaceVFS2 = tg.Leader().MountNamespaceVFS2() - args.MountNamespaceVFS2.IncRef() + if !args.MountNamespaceVFS2.TryIncRef() { + return 0, fmt.Errorf("container %q has stopped", args.ContainerID) + } } else { + var reffed bool tg.Leader().WithMuLocked(func(t *kernel.Task) { // task.MountNamespace() does not take a ref, so we must do so ourselves. args.MountNamespace = t.MountNamespace() - args.MountNamespace.IncRef() + reffed = args.MountNamespace.TryIncRef() }) + if !reffed { + return 0, fmt.Errorf("container %q has stopped", args.ContainerID) + } } // Add the HOME environment variable if it is not already set. diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index bf9ec5d38..1f49431a2 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -264,7 +264,7 @@ type CreateMountTestcase struct { expectedPaths []string } -func createMountTestcases(vfs2 bool) []*CreateMountTestcase { +func createMountTestcases() []*CreateMountTestcase { testCases := []*CreateMountTestcase{ &CreateMountTestcase{ // Only proc. @@ -409,32 +409,26 @@ func createMountTestcases(vfs2 bool) []*CreateMountTestcase { Destination: "/proc", Type: "tmpfs", }, - // TODO (gvisor.dev/issue/1487): Re-add this case when sysfs supports - // MkDirAt in VFS2 (and remove the reduntant append). - // { - // Destination: "/sys/bar", - // Type: "tmpfs", - // }, - // + { + Destination: "/sys/bar", + Type: "tmpfs", + }, + { Destination: "/tmp/baz", Type: "tmpfs", }, }, }, - expectedPaths: []string{"/proc", "/sys" /* "/sys/bar" ,*/, "/tmp", "/tmp/baz"}, + expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz"}, } - if !vfs2 { - vfsCase.spec.Mounts = append(vfsCase.spec.Mounts, specs.Mount{Destination: "/sys/bar", Type: "tmpfs"}) - vfsCase.expectedPaths = append(vfsCase.expectedPaths, "/sys/bar") - } return append(testCases, vfsCase) } // Test that MountNamespace can be created with various specs. func TestCreateMountNamespace(t *testing.T) { - for _, tc := range createMountTestcases(false /* vfs2 */) { + for _, tc := range createMountTestcases() { t.Run(tc.name, func(t *testing.T) { conf := testConfig() ctx := contexttest.Context(t) @@ -471,7 +465,7 @@ func TestCreateMountNamespace(t *testing.T) { // Test that MountNamespace can be created with various specs. func TestCreateMountNamespaceVFS2(t *testing.T) { - for _, tc := range createMountTestcases(true /* vfs2 */) { + for _, tc := range createMountTestcases() { t.Run(tc.name, func(t *testing.T) { spec := testSpec() spec.Mounts = tc.spec.Mounts diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 548c68087..1f8e277cc 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -316,6 +316,7 @@ func configs(t *testing.T, opts ...configOption) map[string]*config.Config { return cs } +// TODO(gvisor.dev/issue/1624): Merge with configs when VFS2 is the default. func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*config.Config { all := configs(t, opts...) for key, value := range configs(t, opts...) { @@ -894,13 +895,15 @@ func TestKillPid(t *testing.T) { } } -// TestCheckpointRestore creates a container that continuously writes successive integers -// to a file. To test checkpoint and restore functionality, the container is -// checkpointed and the last number printed to the file is recorded. Then, it is restored in two -// new containers and the first number printed from these containers is checked. Both should -// be the next consecutive number after the last number from the checkpointed container. +// TestCheckpointRestore creates a container that continuously writes successive +// integers to a file. To test checkpoint and restore functionality, the +// container is checkpointed and the last number printed to the file is +// recorded. Then, it is restored in two new containers and the first number +// printed from these containers is checked. Both should be the next consecutive +// number after the last number from the checkpointed container. func TestCheckpointRestore(t *testing.T) { // Skip overlay because test requires writing to host file. + // TODO(gvisor.dev/issue/1663): Add VFS when S/R support is added. for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test") @@ -1062,6 +1065,7 @@ func TestCheckpointRestore(t *testing.T) { // with filesystem Unix Domain Socket use. func TestUnixDomainSockets(t *testing.T) { // Skip overlay because test requires writing to host file. + // TODO(gvisor.dev/issue/1663): Add VFS when S/R support is added. for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { // UDS path is limited to 108 chars for compatibility with older systems. @@ -1199,7 +1203,7 @@ func TestUnixDomainSockets(t *testing.T) { // recreated. Then it resumes the container, verify that the file gets created // again. func TestPauseResume(t *testing.T) { - for name, conf := range configs(t, noOverlay...) { + for name, conf := range configsWithVFS2(t, noOverlay...) { t.Run(name, func(t *testing.T) { tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "lock") if err != nil { diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 952215ec1..850e80290 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -480,7 +480,7 @@ func TestMultiContainerMount(t *testing.T) { // TestMultiContainerSignal checks that it is possible to signal individual // containers without killing the entire sandbox. func TestMultiContainerSignal(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1691,12 +1691,11 @@ func TestMultiContainerRunNonRoot(t *testing.T) { } // TestMultiContainerHomeEnvDir tests that the HOME environment variable is set -// for root containers, sub-containers, and execed processes. +// for root containers, sub-containers, and exec'ed processes. func TestMultiContainerHomeEnvDir(t *testing.T) { - // TODO(gvisor.dev/issue/1487): VFSv2 configs failing. // NOTE: Don't use overlay since we need changes to persist to the temp dir // outside the sandbox. - for testName, conf := range configs(t, noOverlay...) { + for testName, conf := range configsWithVFS2(t, noOverlay...) { t.Run(testName, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() @@ -1718,9 +1717,9 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { // We will sleep in the root container in order to ensure that the root //container doesn't terminate before sub containers can be created. - rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s; sleep 1000", homeDirs["root"].Name())} - subCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["sub"].Name())} - execCmd := fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name()) + rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf(`printf "$HOME" > %s; sleep 1000`, homeDirs["root"].Name())} + subCmd := []string{"/bin/sh", "-c", fmt.Sprintf(`printf "$HOME" > %s`, homeDirs["sub"].Name())} + execCmd := fmt.Sprintf(`printf "$HOME" > %s`, homeDirs["exec"].Name()) // Setup the containers, a root container and sub container. specConfig, ids := createSpecs(rootCmd, subCmd) diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 0b9f39466..8f66dd1f8 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -309,11 +309,20 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) ( const bufSize = 4 << 20 // 4MB. if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil { - return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", bufSize, err) + syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize) + sz, _ := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF) + + if sz < bufSize { + log.Warningf("Failed to increase rcv buffer to %d on SOCK_RAW on %s. Current buffer %d: %v", bufSize, iface.Name, sz, err) + } } if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil { - return nil, fmt.Errorf("failed to increase socket snd buffer to %d: %v", bufSize, err) + syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize) + sz, _ := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if sz < bufSize { + log.Warningf("Failed to increase snd buffer to %d on SOCK_RAW on %s. Curent buffer %d: %v", bufSize, iface.Name, sz, err) + } } return &socketEntry{deviceFile, gsoMaxSize}, nil diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index a8f4f64a5..c4309feb3 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -72,11 +72,14 @@ type Sandbox struct { // will have it as a child process. child bool - // status is an exit status of a sandbox process. - status syscall.WaitStatus - // statusMu protects status. statusMu sync.Mutex + + // status is the exit status of a sandbox process. It's only set if the + // child==true and the sandbox was waited on. This field allows for multiple + // threads to wait on sandbox and get the exit code, since Linux will return + // WaitStatus to one of the waiters only. + status syscall.WaitStatus } // Args is used to configure a new sandbox. @@ -746,35 +749,47 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn // Wait waits for the containerized process to exit, and returns its WaitStatus. func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) { log.Debugf("Waiting for container %q in sandbox %q", cid, s.ID) - var ws syscall.WaitStatus if conn, err := s.sandboxConnect(); err != nil { - // The sandbox may have exited while before we had a chance to - // wait on it. + // The sandbox may have exited while before we had a chance to wait on it. + // There is nothing we can do for subcontainers. For the init container, we + // can try to get the sandbox exit code. + if !s.IsRootContainer(cid) { + return syscall.WaitStatus(0), err + } log.Warningf("Wait on container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err) } else { defer conn.Close() + // Try the Wait RPC to the sandbox. + var ws syscall.WaitStatus err = conn.Call(boot.ContainerWait, &cid, &ws) if err == nil { // It worked! return ws, nil } + // See comment above. + if !s.IsRootContainer(cid) { + return syscall.WaitStatus(0), err + } + // The sandbox may have exited after we connected, but before // or during the Wait RPC. log.Warningf("Wait RPC to container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err) } - // The sandbox may have already exited, or exited while handling the - // Wait RPC. The best we can do is ask Linux what the sandbox exit - // status was, since in most cases that will be the same as the - // container exit status. + // The sandbox may have already exited, or exited while handling the Wait RPC. + // The best we can do is ask Linux what the sandbox exit status was, since in + // most cases that will be the same as the container exit status. if err := s.waitForStopped(); err != nil { - return ws, err + return syscall.WaitStatus(0), err } if !s.child { - return ws, fmt.Errorf("sandbox no longer running and its exit status is unavailable") + return syscall.WaitStatus(0), fmt.Errorf("sandbox no longer running and its exit status is unavailable") } + + s.statusMu.Lock() + defer s.statusMu.Unlock() return s.status, nil } diff --git a/test/iptables/README.md b/test/iptables/README.md index b9f44bd40..28ab195ca 100644 --- a/test/iptables/README.md +++ b/test/iptables/README.md @@ -1,6 +1,6 @@ # iptables Tests -iptables tests are run via `scripts/iptables_test.sh`. +iptables tests are run via `make iptables-tests`. iptables requires raw socket support, so you must add the `--net-raw=true` flag to `/etc/docker/daemon.json` in order to use it. diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 398f70ecd..834f7615f 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -48,13 +48,6 @@ func singleTest(t *testing.T, test TestCase) { } } -// TODO(gvisor.dev/issue/3549): IPv6 NAT support. -func ipv4Test(t *testing.T, test TestCase) { - t.Run("IPv4", func(t *testing.T) { - iptablesTest(t, test, false) - }) -} - func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { if _, ok := Tests[test.Name()]; !ok { t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) @@ -325,66 +318,66 @@ func TestFilterOutputInvertDestination(t *testing.T) { } func TestNATPreRedirectUDPPort(t *testing.T) { - ipv4Test(t, NATPreRedirectUDPPort{}) + singleTest(t, NATPreRedirectUDPPort{}) } func TestNATPreRedirectTCPPort(t *testing.T) { - ipv4Test(t, NATPreRedirectTCPPort{}) + singleTest(t, NATPreRedirectTCPPort{}) } func TestNATPreRedirectTCPOutgoing(t *testing.T) { - ipv4Test(t, NATPreRedirectTCPOutgoing{}) + singleTest(t, NATPreRedirectTCPOutgoing{}) } func TestNATOutRedirectTCPIncoming(t *testing.T) { - ipv4Test(t, NATOutRedirectTCPIncoming{}) + singleTest(t, NATOutRedirectTCPIncoming{}) } func TestNATOutRedirectUDPPort(t *testing.T) { - ipv4Test(t, NATOutRedirectUDPPort{}) + singleTest(t, NATOutRedirectUDPPort{}) } func TestNATOutRedirectTCPPort(t *testing.T) { - ipv4Test(t, NATOutRedirectTCPPort{}) + singleTest(t, NATOutRedirectTCPPort{}) } func TestNATDropUDP(t *testing.T) { - ipv4Test(t, NATDropUDP{}) + singleTest(t, NATDropUDP{}) } func TestNATAcceptAll(t *testing.T) { - ipv4Test(t, NATAcceptAll{}) + singleTest(t, NATAcceptAll{}) } func TestNATOutRedirectIP(t *testing.T) { - ipv4Test(t, NATOutRedirectIP{}) + singleTest(t, NATOutRedirectIP{}) } func TestNATOutDontRedirectIP(t *testing.T) { - ipv4Test(t, NATOutDontRedirectIP{}) + singleTest(t, NATOutDontRedirectIP{}) } func TestNATOutRedirectInvert(t *testing.T) { - ipv4Test(t, NATOutRedirectInvert{}) + singleTest(t, NATOutRedirectInvert{}) } func TestNATPreRedirectIP(t *testing.T) { - ipv4Test(t, NATPreRedirectIP{}) + singleTest(t, NATPreRedirectIP{}) } func TestNATPreDontRedirectIP(t *testing.T) { - ipv4Test(t, NATPreDontRedirectIP{}) + singleTest(t, NATPreDontRedirectIP{}) } func TestNATPreRedirectInvert(t *testing.T) { - ipv4Test(t, NATPreRedirectInvert{}) + singleTest(t, NATPreRedirectInvert{}) } func TestNATRedirectRequiresProtocol(t *testing.T) { - ipv4Test(t, NATRedirectRequiresProtocol{}) + singleTest(t, NATRedirectRequiresProtocol{}) } func TestNATLoopbackSkipsPrerouting(t *testing.T) { - ipv4Test(t, NATLoopbackSkipsPrerouting{}) + singleTest(t, NATLoopbackSkipsPrerouting{}) } func TestInputSource(t *testing.T) { @@ -421,9 +414,9 @@ func TestFilterAddrs(t *testing.T) { } func TestNATPreOriginalDst(t *testing.T) { - ipv4Test(t, NATPreOriginalDst{}) + singleTest(t, NATPreOriginalDst{}) } func TestNATOutOriginalDst(t *testing.T) { - ipv4Test(t, NATOutOriginalDst{}) + singleTest(t, NATOutOriginalDst{}) } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 94731c64b..11db49e39 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -310,8 +310,6 @@ packetimpact_go_test( packetimpact_go_test( name = "ipv6_fragment_reassembly", srcs = ["ipv6_fragment_reassembly_test.go"], - # TODO(b/160919104): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 032ebd04e..248053dc3 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -203,7 +203,6 @@ def syscall_test( tags = platform_tags + tags, ) - # TODO(gvisor.dev/issue/1487): Enable VFS2 overlay tests. if add_overlay: _syscall_test( test = test, @@ -216,6 +215,23 @@ def syscall_test( overlay = True, ) + # TODO(gvisor.dev/issue/4407): Remove tags to enable VFS2 overlay tests. + overlay_vfs2_tags = list(vfs2_tags) + overlay_vfs2_tags.append("manual") + overlay_vfs2_tags.append("noguitar") + overlay_vfs2_tags.append("notap") + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + overlay_vfs2_tags, + overlay = True, + vfs2 = True, + ) + if add_hostinet: _syscall_test( test = test, diff --git a/test/runtimes/exclude/java11.csv b/test/runtimes/exclude/java11.csv index f779df8d5..d978baca7 100644 --- a/test/runtimes/exclude/java11.csv +++ b/test/runtimes/exclude/java11.csv @@ -57,6 +57,7 @@ java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901, java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker java/rmi/activation/Activatable/extLoadedImpl/ext.sh,, java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,, +java/security/cert/PolicyNode/GetPolicyQualifiers.java,b/170263154,Kokoro executor cert expired java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker java/util/Calendar/JapaneseEraNameTest.java,, diff --git a/test/runtimes/exclude/nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv index 749fb9482..ba993814f 100644 --- a/test/runtimes/exclude/nodejs12.4.0.csv +++ b/test/runtimes/exclude/nodejs12.4.0.csv @@ -13,7 +13,7 @@ parallel/test-dns-channel-timeout.js,b/161893056, parallel/test-fs-access.js,, parallel/test-fs-watchfile.js,,Flaky - File already exists error parallel/test-fs-write-stream.js,b/166819807,Flaky -parallel/test-fs-write-stream-double-close,b/166819807,Flaky +parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 parallel/test-os.js,b/63997097, diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD index d1935cbe8..fdc6d3173 100644 --- a/test/runtimes/proctor/BUILD +++ b/test/runtimes/proctor/BUILD @@ -1,29 +1,11 @@ -load("//tools:defs.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) go_binary( name = "proctor", - srcs = [ - "go.go", - "java.go", - "nodejs.go", - "php.go", - "proctor.go", - "python.go", - ], + srcs = ["main.go"], pure = True, visibility = ["//test/runtimes:__pkg__"], -) - -go_test( - name = "proctor_test", - size = "small", - srcs = ["proctor_test.go"], - library = ":proctor", - nogo = False, # FIXME(gvisor.dev/issue/3374): Not working with all build systems. - pure = True, - deps = [ - "//pkg/test/testutil", - ], + deps = ["//test/runtimes/proctor/lib"], ) diff --git a/test/runtimes/proctor/lib/BUILD b/test/runtimes/proctor/lib/BUILD new file mode 100644 index 000000000..0c8367dfe --- /dev/null +++ b/test/runtimes/proctor/lib/BUILD @@ -0,0 +1,24 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "lib", + srcs = [ + "go.go", + "java.go", + "lib.go", + "nodejs.go", + "php.go", + "python.go", + ], + visibility = ["//test/runtimes/proctor:__pkg__"], +) + +go_test( + name = "lib_test", + size = "small", + srcs = ["lib_test.go"], + library = ":lib", + deps = ["//pkg/test/testutil"], +) diff --git a/test/runtimes/proctor/go.go b/test/runtimes/proctor/lib/go.go index d0ae844e6..5c48fb60b 100644 --- a/test/runtimes/proctor/go.go +++ b/test/runtimes/proctor/lib/go.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" @@ -59,7 +59,7 @@ func (goRunner) ListTests() ([]string, error) { } // Go tests on disk. - diskSlice, err := search(goTestDir, goTestRegEx) + diskSlice, err := Search(goTestDir, goTestRegEx) if err != nil { return nil, err } diff --git a/test/runtimes/proctor/java.go b/test/runtimes/proctor/lib/java.go index d456fa681..3105011ff 100644 --- a/test/runtimes/proctor/java.go +++ b/test/runtimes/proctor/lib/java.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" diff --git a/test/runtimes/proctor/proctor.go b/test/runtimes/proctor/lib/lib.go index 9e0642424..f2ba82498 100644 --- a/test/runtimes/proctor/proctor.go +++ b/test/runtimes/proctor/lib/lib.go @@ -12,20 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary proctor runs the test for a particular runtime. It is meant to be -// included in Docker images for all runtime tests. -package main +// Package lib contains proctor functions. +package lib import ( - "flag" "fmt" - "log" "os" "os/exec" "os/signal" "path/filepath" "regexp" - "strings" "syscall" ) @@ -42,66 +38,8 @@ type TestRunner interface { TestCmds(tests []string) []*exec.Cmd } -var ( - runtime = flag.String("runtime", "", "name of runtime") - list = flag.Bool("list", false, "list all available tests") - testNames = flag.String("tests", "", "run a subset of the available tests") - pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") -) - -func main() { - flag.Parse() - - if *pause { - pauseAndReap() - panic("pauseAndReap should never return") - } - - if *runtime == "" { - log.Fatalf("runtime flag must be provided") - } - - tr, err := testRunnerForRuntime(*runtime) - if err != nil { - log.Fatalf("%v", err) - } - - // List tests. - if *list { - tests, err := tr.ListTests() - if err != nil { - log.Fatalf("failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - - var tests []string - if *testNames == "" { - // Run every test. - tests, err = tr.ListTests() - if err != nil { - log.Fatalf("failed to get all tests: %v", err) - } - } else { - // Run subset of test. - tests = strings.Split(*testNames, ",") - } - - // Run tests. - cmds := tr.TestCmds(tests) - for _, cmd := range cmds { - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - log.Fatalf("FAIL: %v", err) - } - } -} - -// testRunnerForRuntime returns a new TestRunner for the given runtime. -func testRunnerForRuntime(runtime string) (TestRunner, error) { +// TestRunnerForRuntime returns a new TestRunner for the given runtime. +func TestRunnerForRuntime(runtime string) (TestRunner, error) { switch runtime { case "go": return goRunner{}, nil @@ -117,8 +55,8 @@ func testRunnerForRuntime(runtime string) (TestRunner, error) { return nil, fmt.Errorf("invalid runtime %q", runtime) } -// pauseAndReap is like init. It runs forever and reaps any children. -func pauseAndReap() { +// PauseAndReap is like init. It runs forever and reaps any children. +func PauseAndReap() { // Get notified of any new children. ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGCHLD) @@ -138,9 +76,9 @@ func pauseAndReap() { } } -// search is a helper function to find tests in the given directory that match +// Search is a helper function to find tests in the given directory that match // the regex. -func search(root string, testFilter *regexp.Regexp) ([]string, error) { +func Search(root string, testFilter *regexp.Regexp) ([]string, error) { var testSlice []string err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { diff --git a/test/runtimes/proctor/proctor_test.go b/test/runtimes/proctor/lib/lib_test.go index 6ef2de085..1193d2e28 100644 --- a/test/runtimes/proctor/proctor_test.go +++ b/test/runtimes/proctor/lib/lib_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "io/ioutil" @@ -47,7 +47,7 @@ func TestSearchEmptyDir(t *testing.T) { var want []string testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := search(td, testFilter) + got, err := Search(td, testFilter) if err != nil { t.Errorf("search error: %v", err) } @@ -116,7 +116,7 @@ func TestSearch(t *testing.T) { } testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := search(td, testFilter) + got, err := Search(td, testFilter) if err != nil { t.Errorf("search error: %v", err) } diff --git a/test/runtimes/proctor/nodejs.go b/test/runtimes/proctor/lib/nodejs.go index dead5af4f..320597aa5 100644 --- a/test/runtimes/proctor/nodejs.go +++ b/test/runtimes/proctor/lib/nodejs.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "os/exec" @@ -32,7 +32,7 @@ var _ TestRunner = nodejsRunner{} // ListTests implements TestRunner.ListTests. func (nodejsRunner) ListTests() ([]string, error) { - testSlice, err := search(nodejsTestDir, nodejsTestRegEx) + testSlice, err := Search(nodejsTestDir, nodejsTestRegEx) if err != nil { return nil, err } diff --git a/test/runtimes/proctor/php.go b/test/runtimes/proctor/lib/php.go index 6a83d64e3..b67a60a97 100644 --- a/test/runtimes/proctor/php.go +++ b/test/runtimes/proctor/lib/php.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "os/exec" @@ -29,7 +29,7 @@ var _ TestRunner = phpRunner{} // ListTests implements TestRunner.ListTests. func (phpRunner) ListTests() ([]string, error) { - testSlice, err := search(".", phpTestRegEx) + testSlice, err := Search(".", phpTestRegEx) if err != nil { return nil, err } diff --git a/test/runtimes/proctor/python.go b/test/runtimes/proctor/lib/python.go index 7c598801b..429bfd850 100644 --- a/test/runtimes/proctor/python.go +++ b/test/runtimes/proctor/lib/python.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go new file mode 100644 index 000000000..e5607ac92 --- /dev/null +++ b/test/runtimes/proctor/main.go @@ -0,0 +1,85 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary proctor runs the test for a particular runtime. It is meant to be +// included in Docker images for all runtime tests. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strings" + + "gvisor.dev/gvisor/test/runtimes/proctor/lib" +) + +var ( + runtime = flag.String("runtime", "", "name of runtime") + list = flag.Bool("list", false, "list all available tests") + testNames = flag.String("tests", "", "run a subset of the available tests") + pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") +) + +func main() { + flag.Parse() + + if *pause { + lib.PauseAndReap() + panic("pauseAndReap should never return") + } + + if *runtime == "" { + log.Fatalf("runtime flag must be provided") + } + + tr, err := lib.TestRunnerForRuntime(*runtime) + if err != nil { + log.Fatalf("%v", err) + } + + // List tests. + if *list { + tests, err := tr.ListTests() + if err != nil { + log.Fatalf("failed to list tests: %v", err) + } + for _, test := range tests { + fmt.Println(test) + } + return + } + + var tests []string + if *testNames == "" { + // Run every test. + tests, err = tr.ListTests() + if err != nil { + log.Fatalf("failed to get all tests: %v", err) + } + } else { + // Run subset of test. + tests = strings.Split(*testNames, ",") + } + + // Run tests. + cmds := tr.TestCmds(tests) + for _, cmd := range cmds { + cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr + if err := cmd.Run(); err != nil { + log.Fatalf("FAIL: %v", err) + } + } +} diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD index dc0d5d5b4..70cc01594 100644 --- a/test/runtimes/runner/BUILD +++ b/test/runtimes/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -7,16 +7,5 @@ go_binary( testonly = 1, srcs = ["main.go"], visibility = ["//test/runtimes:__pkg__"], - deps = [ - "//pkg/log", - "//pkg/test/dockerutil", - "//pkg/test/testutil", - ], -) - -go_test( - name = "exclude_test", - size = "small", - srcs = ["exclude_test.go"], - library = ":runner", + deps = ["//test/runtimes/runner/lib"], ) diff --git a/test/runtimes/runner/lib/BUILD b/test/runtimes/runner/lib/BUILD new file mode 100644 index 000000000..d308f41b0 --- /dev/null +++ b/test/runtimes/runner/lib/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "lib", + testonly = 1, + srcs = ["lib.go"], + visibility = ["//test/runtimes/runner:__pkg__"], + deps = [ + "//pkg/log", + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) + +go_test( + name = "lib_test", + size = "small", + srcs = ["exclude_test.go"], + library = ":lib", +) diff --git a/test/runtimes/runner/exclude_test.go b/test/runtimes/runner/lib/exclude_test.go index 67c2170c8..f996e895b 100644 --- a/test/runtimes/runner/exclude_test.go +++ b/test/runtimes/runner/lib/exclude_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "flag" @@ -20,6 +20,8 @@ import ( "testing" ) +var excludeFile = flag.String("exclude_file", "", "file to test (standard format)") + func TestMain(m *testing.M) { flag.Parse() os.Exit(m.Run()) @@ -27,7 +29,7 @@ func TestMain(m *testing.M) { // Test that the exclude file parses without error. func TestExcludelist(t *testing.T) { - ex, err := getExcludes() + ex, err := getExcludes(*excludeFile) if err != nil { t.Fatalf("error parsing exclude file: %v", err) } diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go new file mode 100644 index 000000000..78285cb0e --- /dev/null +++ b/test/runtimes/runner/lib/lib.go @@ -0,0 +1,185 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package lib provides utilities for runner. +package lib + +import ( + "context" + "encoding/csv" + "fmt" + "io" + "os" + "sort" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// RunTests is a helper that is called by main. It exists so that we can run +// defered functions before exiting. It returns an exit code that should be +// passed to os.Exit. +func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int { + // Get tests to exclude.. + excludes, err := getExcludes(excludeFile) + if err != nil { + fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) + return 1 + } + + // Construct the shared docker instance. + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(lang)) + defer d.CleanUp(ctx) + + if err := testutil.TouchShardStatusFile(); err != nil { + fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err) + return 1 + } + + // Get a slice of tests to run. This will also start a single Docker + // container that will be used to run each test. The final test will + // stop the Docker container. + tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + return 1 + } + + m := testing.MainStart(testDeps{}, tests, nil, nil) + return m.Run() +} + +// getTests executes all tests as table tests. +func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) { + // Start the container. + opts := dockerutil.RunOpts{ + Image: fmt.Sprintf("runtimes/%s", image), + } + d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor") + if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil { + return nil, fmt.Errorf("docker run failed: %v", err) + } + + // Get a list of all tests in the image. + list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--list") + if err != nil { + return nil, fmt.Errorf("docker exec failed: %v", err) + } + + // Calculate a subset of tests to run corresponding to the current + // shard. + tests := strings.Fields(list) + sort.Strings(tests) + indices, err := testutil.TestIndicesForShard(len(tests)) + if err != nil { + return nil, fmt.Errorf("TestsForShard() failed: %v", err) + } + + var itests []testing.InternalTest + for i := 0; i < len(indices); i += batchSize { + var tcs []string + end := i + batchSize + if end > len(indices) { + end = len(indices) + } + for _, tc := range indices[i:end] { + // Add test if not excluded. + if _, ok := excludes[tests[tc]]; ok { + log.Infof("Skipping test case %s\n", tests[tc]) + continue + } + tcs = append(tcs, tests[tc]) + } + itests = append(itests, testing.InternalTest{ + Name: strings.Join(tcs, ", "), + F: func(t *testing.T) { + var ( + now = time.Now() + done = make(chan struct{}) + output string + err error + ) + + go func() { + fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) + output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--tests", strings.Join(tcs, ",")) + close(done) + }() + + select { + case <-done: + if err == nil { + fmt.Printf("PASS: (%v)\n\n", time.Since(now)) + return + } + t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) + case <-time.After(timeout): + t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) + } + }, + }) + } + + return itests, nil +} + +// getBlacklist reads the exclude file and returns a set of test names to +// exclude. +func getExcludes(excludeFile string) (map[string]struct{}, error) { + excludes := make(map[string]struct{}) + if excludeFile == "" { + return excludes, nil + } + f, err := os.Open(excludeFile) + if err != nil { + return nil, err + } + defer f.Close() + + r := csv.NewReader(f) + + // First line is header. Skip it. + if _, err := r.Read(); err != nil { + return nil, err + } + + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + excludes[record[0]] = struct{}{} + } + return excludes, nil +} + +// testDeps implements testing.testDeps (an unexported interface), and is +// required to use testing.MainStart. +type testDeps struct{} + +func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil } +func (f testDeps) StartCPUProfile(io.Writer) error { return nil } +func (f testDeps) StopCPUProfile() {} +func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } +func (f testDeps) ImportPath() string { return "" } +func (f testDeps) StartTestLog(io.Writer) {} +func (f testDeps) StopTestLog() error { return nil } diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go index 948e7cf9c..ec79a22c2 100644 --- a/test/runtimes/runner/main.go +++ b/test/runtimes/runner/main.go @@ -16,20 +16,12 @@ package main import ( - "context" - "encoding/csv" "flag" "fmt" - "io" "os" - "sort" - "strings" - "testing" "time" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/test/runtimes/runner/lib" ) var ( @@ -37,169 +29,14 @@ var ( image = flag.String("image", "", "docker image with runtime tests") excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") batchSize = flag.Int("batch", 50, "number of test cases run in one command") + timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout") ) -// Wait time for each test to run. -const timeout = 90 * time.Minute - func main() { flag.Parse() if *lang == "" || *image == "" { fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") os.Exit(1) } - os.Exit(runTests()) -} - -// runTests is a helper that is called by main. It exists so that we can run -// defered functions before exiting. It returns an exit code that should be -// passed to os.Exit. -func runTests() int { - // Get tests to exclude.. - excludes, err := getExcludes() - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) - return 1 - } - - // Construct the shared docker instance. - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(*lang)) - defer d.CleanUp(ctx) - - if err := testutil.TouchShardStatusFile(); err != nil { - fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err) - return 1 - } - - // Get a slice of tests to run. This will also start a single Docker - // container that will be used to run each test. The final test will - // stop the Docker container. - tests, err := getTests(ctx, d, excludes) - if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err.Error()) - return 1 - } - - m := testing.MainStart(testDeps{}, tests, nil, nil) - return m.Run() -} - -// getTests executes all tests as table tests. -func getTests(ctx context.Context, d *dockerutil.Container, excludes map[string]struct{}) ([]testing.InternalTest, error) { - // Start the container. - opts := dockerutil.RunOpts{ - Image: fmt.Sprintf("runtimes/%s", *image), - } - d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor") - if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil { - return nil, fmt.Errorf("docker run failed: %v", err) - } - - // Get a list of all tests in the image. - list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--list") - if err != nil { - return nil, fmt.Errorf("docker exec failed: %v", err) - } - - // Calculate a subset of tests to run corresponding to the current - // shard. - tests := strings.Fields(list) - sort.Strings(tests) - indices, err := testutil.TestIndicesForShard(len(tests)) - if err != nil { - return nil, fmt.Errorf("TestsForShard() failed: %v", err) - } - - var itests []testing.InternalTest - for i := 0; i < len(indices); i += *batchSize { - var tcs []string - end := i + *batchSize - if end > len(indices) { - end = len(indices) - } - for _, tc := range indices[i:end] { - // Add test if not excluded. - if _, ok := excludes[tests[tc]]; ok { - log.Infof("Skipping test case %s\n", tests[tc]) - continue - } - tcs = append(tcs, tests[tc]) - } - itests = append(itests, testing.InternalTest{ - Name: strings.Join(tcs, ", "), - F: func(t *testing.T) { - var ( - now = time.Now() - done = make(chan struct{}) - output string - err error - ) - - go func() { - fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) - output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--tests", strings.Join(tcs, ",")) - close(done) - }() - - select { - case <-done: - if err == nil { - fmt.Printf("PASS: (%v)\n\n", time.Since(now)) - return - } - t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) - case <-time.After(timeout): - t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) - } - }, - }) - } - - return itests, nil + os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout)) } - -// getBlacklist reads the exclude file and returns a set of test names to -// exclude. -func getExcludes() (map[string]struct{}, error) { - excludes := make(map[string]struct{}) - if *excludeFile == "" { - return excludes, nil - } - f, err := os.Open(*excludeFile) - if err != nil { - return nil, err - } - defer f.Close() - - r := csv.NewReader(f) - - // First line is header. Skip it. - if _, err := r.Read(); err != nil { - return nil, err - } - - for { - record, err := r.Read() - if err == io.EOF { - break - } - if err != nil { - return nil, err - } - excludes[record[0]] = struct{}{} - } - return excludes, nil -} - -// testDeps implements testing.testDeps (an unexported interface), and is -// required to use testing.MainStart. -type testDeps struct{} - -func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil } -func (f testDeps) StartCPUProfile(io.Writer) error { return nil } -func (f testDeps) StopCPUProfile() {} -func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } -func (f testDeps) ImportPath() string { return "" } -func (f testDeps) StartTestLog(io.Writer) {} -func (f testDeps) StopTestLog() error { return nil } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 96a775456..f66a9ceb4 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -286,6 +286,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:membarrier_test", +) + +syscall_test( test = "//test/syscalls/linux:memory_accounting_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index c775a6d75..36b7f1b97 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1079,6 +1079,7 @@ cc_binary( gtest, "//test/util:test_main", "//test/util:test_util", + "//test/util:thread_util", ], ) @@ -1155,6 +1156,24 @@ cc_binary( ) cc_binary( + name = "membarrier_test", + testonly = 1, + srcs = ["membarrier.cc"], + linkstatic = 1, + deps = [ + "@com_google_absl//absl/time", + gtest, + "//test/util:cleanup", + "//test/util:logging", + "//test/util:memory_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( name = "mempolicy_test", testonly = 1, srcs = ["mempolicy.cc"], @@ -1667,6 +1686,7 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/test/syscalls/linux/ip6tables.cc b/test/syscalls/linux/ip6tables.cc index 97297ee2b..de0a1c114 100644 --- a/test/syscalls/linux/ip6tables.cc +++ b/test/syscalls/linux/ip6tables.cc @@ -82,6 +82,32 @@ TEST(IP6TablesBasic, GetEntriesErrorPrecedence) { SyscallFailsWithErrno(EINVAL)); } +TEST(IP6TablesBasic, GetRevision) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW), + SyscallSucceeds()); + + struct xt_get_revision rev = { + .name = "REDIRECT", + .revision = 0, + }; + socklen_t rev_len = sizeof(rev); + + // Revision 0 exists. + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallSucceeds()); + EXPECT_EQ(rev.revision, 0); + + // Revisions > 0 don't exist. + rev.revision = 1; + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallFailsWithErrno(EPROTONOSUPPORT)); +} + // This tests the initial state of a machine with empty ip6tables via // getsockopt(IP6T_SO_GET_INFO). We don't have a guarantee that the iptables are // empty when running in native, but we can test that gVisor has the same diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc index 83b6a164a..7ee10bbde 100644 --- a/test/syscalls/linux/iptables.cc +++ b/test/syscalls/linux/iptables.cc @@ -117,6 +117,32 @@ TEST(IPTablesBasic, OriginalDstErrors) { SyscallFailsWithErrno(ENOTCONN)); } +TEST(IPTablesBasic, GetRevision) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), + SyscallSucceeds()); + + struct xt_get_revision rev = { + .name = "REDIRECT", + .revision = 0, + }; + socklen_t rev_len = sizeof(rev); + + // Revision 0 exists. + EXPECT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallSucceeds()); + EXPECT_EQ(rev.revision, 0); + + // Revisions > 0 don't exist. + rev.revision = 1; + EXPECT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallFailsWithErrno(EPROTONOSUPPORT)); +} + // Fixture for iptables tests. class IPTablesTest : public ::testing::Test { protected: diff --git a/test/syscalls/linux/kcov.cc b/test/syscalls/linux/kcov.cc index 6afcb4e75..6816c1fd0 100644 --- a/test/syscalls/linux/kcov.cc +++ b/test/syscalls/linux/kcov.cc @@ -16,39 +16,47 @@ #include <sys/ioctl.h> #include <sys/mman.h> +#include <atomic> + #include "gtest/gtest.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { namespace { -// For this test to work properly, it must be run with coverage enabled. On +// For this set of tests to run, they must be run with coverage enabled. On // native Linux, this involves compiling the kernel with kcov enabled. For -// gVisor, we need to enable the Go coverage tool, e.g. -// bazel test --collect_coverage_data --instrumentation_filter=//pkg/... <test>. +// gVisor, we need to enable the Go coverage tool, e.g. bazel test -- +// collect_coverage_data --instrumentation_filter=//pkg/... <test>. + +constexpr char kcovPath[] = "/sys/kernel/debug/kcov"; +constexpr int kSize = 4096; +constexpr int KCOV_INIT_TRACE = 0x80086301; +constexpr int KCOV_ENABLE = 0x6364; +constexpr int KCOV_DISABLE = 0x6365; + +uint64_t* KcovMmap(int fd) { + return (uint64_t*)mmap(nullptr, kSize * sizeof(uint64_t), + PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); +} + TEST(KcovTest, Kcov) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); - constexpr int kSize = 4096; - constexpr int KCOV_INIT_TRACE = 0x80086301; - constexpr int KCOV_ENABLE = 0x6364; - constexpr int KCOV_DISABLE = 0x6365; - int fd; - ASSERT_THAT(fd = open("/sys/kernel/debug/kcov", O_RDWR), + ASSERT_THAT(fd = open(kcovPath, O_RDWR), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); - // Kcov not available. SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); - uint64_t* area = (uint64_t*)mmap(nullptr, kSize * sizeof(uint64_t), - PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - + uint64_t* area = KcovMmap(fd); ASSERT_TRUE(area != MAP_FAILED); ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); @@ -67,6 +75,109 @@ TEST(KcovTest, Kcov) { ASSERT_THAT(ioctl(fd, KCOV_DISABLE, 0), SyscallSucceeds()); } +TEST(KcovTest, PrematureMmap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd; + ASSERT_THAT(fd = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); + + // Cannot mmap before KCOV_INIT_TRACE. + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area == MAP_FAILED); +} + +// Tests that multiple kcov fds can be used simultaneously. +TEST(KcovTest, MultipleFds) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd1; + ASSERT_THAT(fd1 = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + + int fd2; + ASSERT_THAT(fd2 = open(kcovPath, O_RDWR), SyscallSucceeds()); + auto fd_closer = Cleanup([fd1, fd2]() { + close(fd1); + close(fd2); + }); + + auto t1 = ScopedThread([&] { + ASSERT_THAT(ioctl(fd1, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd1); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd1, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + ASSERT_THAT(ioctl(fd2, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd2); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd2, KCOV_ENABLE, 0), SyscallSucceeds()); +} + +// Tests behavior for two threads trying to use the same kcov fd. +TEST(KcovTest, MultipleThreads) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd; + ASSERT_THAT(fd = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); + + // Test the behavior of multiple threads trying to use the same kcov fd + // simultaneously. + std::atomic<bool> t1_enabled(false), t1_disabled(false), t2_failed(false), + t2_exited(false); + auto t1 = ScopedThread([&] { + ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + t1_enabled = true; + + // After t2 has made sure that enabling kcov again fails, disable it. + while (!t2_failed) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_DISABLE, 0), SyscallSucceeds()); + t1_disabled = true; + + // Wait for t2 to enable kcov and then exit, after which we should be able + // to enable kcov again, without needing to set up a new memory mapping. + while (!t2_exited) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + auto t2 = ScopedThread([&] { + // Wait for t1 to enable kcov, and make sure that enabling kcov again fails. + while (!t1_enabled) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallFailsWithErrno(EINVAL)); + t2_failed = true; + + // Wait for t1 to disable kcov, after which using fd should now succeed. + while (!t1_disabled) { + sched_yield(); + } + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + t2.Join(); + t2_exited = true; +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/membarrier.cc b/test/syscalls/linux/membarrier.cc new file mode 100644 index 000000000..516956a25 --- /dev/null +++ b/test/syscalls/linux/membarrier.cc @@ -0,0 +1,268 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <signal.h> +#include <sys/syscall.h> +#include <sys/types.h> +#include <unistd.h> + +#include <atomic> + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/util/cleanup.h" +#include "test/util/logging.h" +#include "test/util/memory_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// This is the classic test case for memory fences on architectures with total +// store ordering; see e.g. Intel SDM Vol. 3A Sec. 8.2.3.4 "Loads May Be +// Reordered with Earlier Stores to Different Locations". In each iteration of +// the test, given two variables X and Y initially set to 0 +// (MembarrierTestSharedState::local_var and remote_var in the code), two +// threads execute as follows: +// +// T1 T2 +// -- -- +// +// X = 1 Y = 1 +// T1fence() T2fence() +// read Y read X +// +// On architectures where memory writes may be locally buffered by each CPU +// (essentially all architectures), if T1fence() and T2fence() are omitted or +// ineffective, it is possible for both T1 and T2 to read 0 because the memory +// write from the other CPU is not yet visible outside that CPU. T1fence() and +// T2fence() are expected to perform the necessary synchronization to restore +// sequential consistency: both threads agree on a order of memory accesses that +// is consistent with program order in each thread, such that at least one +// thread reads 1. +// +// In the NoMembarrier test, T1fence() and T2fence() are both ordinary memory +// fences establishing ordering between memory accesses before and after the +// fence (std::atomic_thread_fence). In all other test cases, T1fence() is not a +// memory fence at all, but only prevents compiler reordering of memory accesses +// (std::atomic_signal_fence); T2fence() is an invocation of the membarrier() +// syscall, which establishes ordering of memory accesses before and after the +// syscall on both threads. + +template <typename F> +int DoMembarrierTestSide(std::atomic<int>* our_var, + std::atomic<int> const& their_var, + F const& test_fence) { + our_var->store(1, std::memory_order_relaxed); + test_fence(); + return their_var.load(std::memory_order_relaxed); +} + +struct MembarrierTestSharedState { + std::atomic<int64_t> remote_iter_cur; + std::atomic<int64_t> remote_iter_done; + std::atomic<int> local_var; + std::atomic<int> remote_var; + int remote_obs_of_local_var; + + void Init() { + remote_iter_cur.store(-1, std::memory_order_relaxed); + remote_iter_done.store(-1, std::memory_order_relaxed); + } +}; + +// Special value for MembarrierTestSharedState::remote_iter_cur indicating that +// the remote thread should terminate. +constexpr int64_t kRemoteIterStop = -2; + +// Must be async-signal-safe. +template <typename F> +void RunMembarrierTestRemoteSide(MembarrierTestSharedState* state, + F const& test_fence) { + int64_t i = 0; + int64_t cur; + while (true) { + while ((cur = state->remote_iter_cur.load(std::memory_order_acquire)) < i) { + if (cur == kRemoteIterStop) { + return; + } + // spin + } + state->remote_obs_of_local_var = + DoMembarrierTestSide(&state->remote_var, state->local_var, test_fence); + state->remote_iter_done.store(i, std::memory_order_release); + i++; + } +} + +template <typename F> +void RunMembarrierTestLocalSide(MembarrierTestSharedState* state, + F const& test_fence) { + // On test completion, instruct the remote thread to terminate. + Cleanup cleanup_remote([&] { + state->remote_iter_cur.store(kRemoteIterStop, std::memory_order_relaxed); + }); + + int64_t i = 0; + absl::Time end = absl::Now() + absl::Seconds(5); // arbitrary test duration + while (absl::Now() < end) { + // Reset both vars to 0. + state->local_var.store(0, std::memory_order_relaxed); + state->remote_var.store(0, std::memory_order_relaxed); + // Instruct the remote thread to begin this iteration. + state->remote_iter_cur.store(i, std::memory_order_release); + // Perform our side of the test. + auto local_obs_of_remote_var = + DoMembarrierTestSide(&state->local_var, state->remote_var, test_fence); + // Wait for the remote thread to finish this iteration. + while (state->remote_iter_done.load(std::memory_order_acquire) < i) { + // spin + } + ASSERT_TRUE(local_obs_of_remote_var != 0 || + state->remote_obs_of_local_var != 0); + i++; + } +} + +TEST(MembarrierTest, NoMembarrier) { + MembarrierTestSharedState state; + state.Init(); + + ScopedThread remote_thread([&] { + RunMembarrierTestRemoteSide( + &state, [] { std::atomic_thread_fence(std::memory_order_seq_cst); }); + }); + RunMembarrierTestLocalSide( + &state, [] { std::atomic_thread_fence(std::memory_order_seq_cst); }); +} + +enum membarrier_cmd { + MEMBARRIER_CMD_QUERY = 0, + MEMBARRIER_CMD_GLOBAL = (1 << 0), + MEMBARRIER_CMD_GLOBAL_EXPEDITED = (1 << 1), + MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED = (1 << 2), + MEMBARRIER_CMD_PRIVATE_EXPEDITED = (1 << 3), + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED = (1 << 4), +}; + +int membarrier(membarrier_cmd cmd, int flags) { + return syscall(SYS_membarrier, cmd, flags); +} + +PosixErrorOr<int> SupportedMembarrierCommands() { + int cmds = membarrier(MEMBARRIER_CMD_QUERY, 0); + if (cmds < 0) { + if (errno == ENOSYS) { + // No commands are supported. + return 0; + } + return PosixError(errno, "membarrier(MEMBARRIER_CMD_QUERY) failed"); + } + return cmds; +} + +TEST(MembarrierTest, Global) { + SKIP_IF((ASSERT_NO_ERRNO_AND_VALUE(SupportedMembarrierCommands()) & + MEMBARRIER_CMD_GLOBAL) == 0); + + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED)); + auto state = static_cast<MembarrierTestSharedState*>(m.ptr()); + state->Init(); + + pid_t const child_pid = fork(); + if (child_pid == 0) { + // In child process. + RunMembarrierTestRemoteSide( + state, [] { TEST_PCHECK(membarrier(MEMBARRIER_CMD_GLOBAL, 0) == 0); }); + _exit(0); + } + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + Cleanup cleanup_child([&] { + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) + << " status " << status; + }); + RunMembarrierTestLocalSide( + state, [] { std::atomic_signal_fence(std::memory_order_seq_cst); }); +} + +TEST(MembarrierTest, GlobalExpedited) { + constexpr int kRequiredCommands = MEMBARRIER_CMD_GLOBAL_EXPEDITED | + MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED; + SKIP_IF((ASSERT_NO_ERRNO_AND_VALUE(SupportedMembarrierCommands()) & + kRequiredCommands) != kRequiredCommands); + + ASSERT_THAT(membarrier(MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED, 0), + SyscallSucceeds()); + + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED)); + auto state = static_cast<MembarrierTestSharedState*>(m.ptr()); + state->Init(); + + pid_t const child_pid = fork(); + if (child_pid == 0) { + // In child process. + RunMembarrierTestRemoteSide(state, [] { + TEST_PCHECK(membarrier(MEMBARRIER_CMD_GLOBAL_EXPEDITED, 0) == 0); + }); + _exit(0); + } + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + Cleanup cleanup_child([&] { + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) + << " status " << status; + }); + RunMembarrierTestLocalSide( + state, [] { std::atomic_signal_fence(std::memory_order_seq_cst); }); +} + +TEST(MembarrierTest, PrivateExpedited) { + constexpr int kRequiredCommands = MEMBARRIER_CMD_PRIVATE_EXPEDITED | + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED; + SKIP_IF((ASSERT_NO_ERRNO_AND_VALUE(SupportedMembarrierCommands()) & + kRequiredCommands) != kRequiredCommands); + + ASSERT_THAT(membarrier(MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED, 0), + SyscallSucceeds()); + + MembarrierTestSharedState state; + state.Init(); + + ScopedThread remote_thread([&] { + RunMembarrierTestRemoteSide(&state, [] { + TEST_PCHECK(membarrier(MEMBARRIER_CMD_PRIVATE_EXPEDITED, 0) == 0); + }); + }); + RunMembarrierTestLocalSide( + &state, [] { std::atomic_signal_fence(std::memory_order_seq_cst); }); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index c1488b06b..e8fcc4439 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -47,6 +47,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_set.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" @@ -721,8 +722,8 @@ static void CheckFdDirGetdentsDuplicates(const std::string& path) { EXPECT_GE(newfd, 1024); auto fd_closer = Cleanup([newfd]() { close(newfd); }); auto fd_files = ASSERT_NO_ERRNO_AND_VALUE(ListDir(path.c_str(), false)); - std::unordered_set<std::string> fd_files_dedup(fd_files.begin(), - fd_files.end()); + absl::node_hash_set<std::string> fd_files_dedup(fd_files.begin(), + fd_files.end()); EXPECT_EQ(fd_files.size(), fd_files_dedup.size()); } @@ -779,8 +780,12 @@ TEST(ProcSelfFdInfo, Flags) { } TEST(ProcSelfExe, Absolute) { - auto exe = ASSERT_NO_ERRNO_AND_VALUE( - ReadLink(absl::StrCat("/proc/", getpid(), "/exe"))); + auto exe = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe")); + EXPECT_EQ(exe[0], '/'); +} + +TEST(ProcSelfCwd, Absolute) { + auto exe = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/cwd")); EXPECT_EQ(exe[0], '/'); } @@ -1472,6 +1477,16 @@ TEST(ProcPidExe, Subprocess) { EXPECT_EQ(actual, expected_absolute_path); } +// /proc/PID/cwd points to the correct directory. +TEST(ProcPidCwd, Subprocess) { + auto want = ASSERT_NO_ERRNO_AND_VALUE(GetCWD()); + + char got[PATH_MAX + 1] = {}; + ASSERT_THAT(ReadlinkWhileRunning("cwd", got, sizeof(got)), + SyscallSucceedsWithValue(Gt(0))); + EXPECT_EQ(got, want); +} + // Test whether /proc/PID/ files can be read for a running process. TEST(ProcPidFile, SubprocessRunning) { char buf[1]; diff --git a/test/util/BUILD b/test/util/BUILD index fc5fb3a8d..26c2b6a2f 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system") +load("//tools:defs.bzl", "cc_library", "cc_test", "coreutil", "gbenchmark", "gtest", "select_system") package( default_visibility = ["//:sandbox"], @@ -254,7 +254,7 @@ cc_library( ], hdrs = ["test_util.h"], defines = select_system(), - deps = [ + deps = coreutil() + [ ":fs_util", ":logging", ":posix_error", diff --git a/tools/bazel.mk b/tools/bazel.mk index 4235c36ca..25575c02c 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -19,6 +19,7 @@ SHELL=/bin/bash -o pipefail BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ xargs -n 1 basename 2>/dev/null) +BUILD_ROOT := $(CURDIR)/bazel-bin/ # Bazel container configuration (see below). USER ?= gvisor @@ -151,8 +152,9 @@ build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefai build_paths = $(build_cmd) 2>&1 \ | tee /proc/self/fd/2 \ - | grep -E "^ bazel-bin/" \ - | tr -d '\r' \ + | grep " bazel-bin/" \ + | sed "s/ /\n/g" \ + | strings -n 10 \ | awk '{$$1=$$1};1' \ | xargs -n 1 -I {} sh -c "$(1)" diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index dad5fc3b2..cf5b1dc0d 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -191,3 +191,6 @@ def default_installer(): def default_net_util(): return [] # Nothing needed. + +def coreutil(): + return [] # Nothing needed. diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index d98f5c3a1..f5bba9980 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -398,7 +398,37 @@ func loadObjdump() (map[string][]string, error) { return nil, err } + // Identify calls by address or name. Note that this is also + // constructed dynamically below, as we encounted the addresses. + // This is because some of the functions (duffzero) may have + // jump targets in the middle of the function itself. + funcsAllowed := map[string]struct{}{ + "runtime.duffzero": struct{}{}, + "runtime.duffcopy": struct{}{}, + "runtime.racefuncenter": struct{}{}, + "runtime.gcWriteBarrier": struct{}{}, + "runtime.retpolineAX": struct{}{}, + "runtime.retpolineBP": struct{}{}, + "runtime.retpolineBX": struct{}{}, + "runtime.retpolineCX": struct{}{}, + "runtime.retpolineDI": struct{}{}, + "runtime.retpolineDX": struct{}{}, + "runtime.retpolineR10": struct{}{}, + "runtime.retpolineR11": struct{}{}, + "runtime.retpolineR12": struct{}{}, + "runtime.retpolineR13": struct{}{}, + "runtime.retpolineR14": struct{}{}, + "runtime.retpolineR15": struct{}{}, + "runtime.retpolineR8": struct{}{}, + "runtime.retpolineR9": struct{}{}, + "runtime.retpolineSI": struct{}{}, + "runtime.stackcheck": struct{}{}, + "runtime.settls": struct{}{}, + } + addrsAllowed := make(map[string]struct{}) + // Build the map. + nextFunc := "" // For funcsAllowed. m := make(map[string][]string) r := bufio.NewReader(out) NextLine: @@ -407,6 +437,19 @@ NextLine: if err != nil && err != io.EOF { return nil, err } + fields := strings.Fields(line) + + // Is this an "allowed" function definition? + if len(fields) >= 2 && fields[0] == "TEXT" { + nextFunc = strings.TrimSuffix(fields[1], "(SB)") + if _, ok := funcsAllowed[nextFunc]; !ok { + nextFunc = "" // Don't record addresses. + } + } + if nextFunc != "" && len(fields) > 2 { + // Save the given address (in hex form, as it appears). + addrsAllowed[fields[1]] = struct{}{} + } // We recognize lines corresponding to actual code (not the // symbol name or other metadata) and annotate them if they @@ -416,53 +459,31 @@ NextLine: // // Lines look like this (including the first space): // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX - if len(line) > 0 && line[0] == ' ' { - fields := strings.Fields(line) + if len(fields) >= 5 && line[0] == ' ' { if !strings.Contains(fields[3], "CALL") { continue } - site := strings.TrimSpace(fields[0]) - var callStr string // Friendly string. - if len(fields) > 5 { - callStr = strings.Join(fields[5:], " ") - } - if len(callStr) == 0 { - // Just a raw call? is this asm? - callStr = strings.Join(fields[3:], " ") - } - - // Ignore strings containing duffzero, which is just - // used by stack allocations for types that are large - // enough to warrant Duff's device. - if strings.Contains(callStr, "runtime.duffzero") || - strings.Contains(callStr, "runtime.duffcopy") { - continue - } - - // Ignore the racefuncenter call, which is used for - // race builds. This does not escape. - if strings.Contains(callStr, "runtime.racefuncenter") { - continue - } + site := fields[0] + target := strings.TrimSuffix(fields[4], "(SB)") - // Ignore the write barriers. - if strings.Contains(callStr, "runtime.gcWriteBarrier") { + // Ignore strings containing allowed functions. + if _, ok := funcsAllowed[target]; ok { continue } - - // Ignore retpolines. - if strings.Contains(callStr, "runtime.retpoline") { + if _, ok := addrsAllowed[target]; ok { continue } - - // Ignore stack sanity check (does not split). - if strings.Contains(callStr, "runtime.stackcheck") { - continue - } - - // Ignore tls functions. - if strings.Contains(callStr, "runtime.settls") { - continue + if len(fields) > 5 { + // This may be a future relocation. Some + // objdump versions describe this differently. + // If it contains any of the functions allowed + // above as a string, we let it go. + softTarget := strings.Join(fields[5:], " ") + for name := range funcsAllowed { + if strings.Contains(softTarget, name) { + continue NextLine + } + } } // Does this exist already? @@ -471,11 +492,11 @@ NextLine: existing = make([]string, 0, 1) } for _, other := range existing { - if callStr == other { + if target == other { continue NextLine } } - existing = append(existing, callStr) + existing = append(existing, target) m[site] = existing // Update. } if err == io.EOF { @@ -483,12 +504,25 @@ NextLine: } } + // Zap any accidental false positives. + final := make(map[string][]string) + for site, calls := range m { + filteredCalls := make([]string, 0, len(calls)) + for _, call := range calls { + if _, ok := addrsAllowed[call]; ok { + continue // Omit this call. + } + filteredCalls = append(filteredCalls, call) + } + final[site] = filteredCalls + } + // Wait for the dump to finish. if err := cmd.Wait(); err != nil { return nil, err } - return m, nil + return final, nil } // poser is a type that implements Pos. diff --git a/tools/defs.bzl b/tools/defs.bzl index 290d564f2..079ab806f 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,7 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option") +load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") load("//tools/bazeldefs:tags.bzl", "go_suffixes") load("//tools/nogo:defs.bzl", "nogo_test") @@ -39,6 +39,7 @@ short_path = _short_path rbe_platform = _rbe_platform rbe_toolchain = _rbe_toolchain vdso_linker_option = _vdso_linker_option +coreutil = _coreutil # Platform options. default_platform = _default_platform @@ -214,7 +215,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F for (suffix, _) in marshal_sets.items(): _go_test( name = name + suffix + "_abi_autogen_test", - srcs = [name + suffix + "_abi_autogen_test.go"], + srcs = [ + name + suffix + "_abi_autogen_test.go", + name + suffix + "_abi_autogen_unconditional_test.go", + ], library = ":" + name, deps = marshal_test_deps, **kwargs diff --git a/tools/github/BUILD b/tools/github/BUILD new file mode 100644 index 000000000..aad088d13 --- /dev/null +++ b/tools/github/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "github", + srcs = ["main.go"], + nogo = False, + deps = [ + "//tools/github/nogo", + "//tools/github/reviver", + "@com_github_google_go_github_v28//github:go_default_library", + "@org_golang_x_oauth2//:go_default_library", + ], +) diff --git a/tools/github/main.go b/tools/github/main.go new file mode 100644 index 000000000..7a74dc033 --- /dev/null +++ b/tools/github/main.go @@ -0,0 +1,162 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary github is the entry point for GitHub utilities. +package main + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "os" + "os/exec" + "strings" + + "github.com/google/go-github/github" + "golang.org/x/oauth2" + "gvisor.dev/gvisor/tools/github/nogo" + "gvisor.dev/gvisor/tools/github/reviver" +) + +var ( + owner string + repo string + tokenFile string + path string + commit string + dryRun bool +) + +// Keep the options simple for now. Supports only a single path and repo. +func init() { + flag.StringVar(&owner, "owner", "", "GitHub project org/owner (required, except nogo dry-run)") + flag.StringVar(&repo, "repo", "", "GitHub repo (required, except nogo dry-run)") + flag.StringVar(&tokenFile, "oauth-token-file", "", "file containing the GitHub token (or GITHUB_TOKEN is set)") + flag.StringVar(&path, "path", ".", "path to scan (required for revive and nogo)") + flag.StringVar(&commit, "commit", "", "commit to associated (required for nogo, except dry-run)") + flag.BoolVar(&dryRun, "dry-run", false, "just print changes to be made") +} + +func main() { + // Set defaults from the environment. + repository := os.Getenv("GITHUB_REPOSITORY") + if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 { + owner = parts[0] + repo = parts[1] + } + + // Parse flags. + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "usage: %s [options] <command>\n", os.Args[0]) + fmt.Fprintf(flag.CommandLine.Output(), "commands: revive, nogo\n") + flag.PrintDefaults() + } + flag.Parse() + args := flag.Args() + if len(args) != 1 { + fmt.Fprintf(flag.CommandLine.Output(), "extra arguments: %s\n", strings.Join(args[1:], ", ")) + flag.Usage() + os.Exit(1) + } + + // Check for mandatory parameters. + command := args[0] + if len(owner) == 0 && (command != "nogo" || !dryRun) { + fmt.Fprintln(flag.CommandLine.Output(), "missing --owner option.") + flag.Usage() + os.Exit(1) + } + if len(repo) == 0 && (command != "nogo" || !dryRun) { + fmt.Fprintln(flag.CommandLine.Output(), "missing --repo option.") + flag.Usage() + os.Exit(1) + } + if len(path) == 0 { + fmt.Fprintln(flag.CommandLine.Output(), "missing --path option.") + flag.Usage() + os.Exit(1) + } + + // The access token may be passed as a file so it doesn't show up in + // command line arguments. It also may be provided through the + // environment to faciliate use through GitHub's CI system. + token := os.Getenv("GITHUB_TOKEN") + if len(tokenFile) != 0 { + bytes, err := ioutil.ReadFile(tokenFile) + if err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } + token = string(bytes) + } + var client *github.Client + if len(token) == 0 { + // Client is unauthenticated. + client = github.NewClient(nil) + } else { + // Using the above token. + ts := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: token}, + ) + tc := oauth2.NewClient(context.Background(), ts) + client = github.NewClient(tc) + } + + switch command { + case "revive": + // Load existing GitHub bugs. + bugger, err := reviver.NewGitHubBugger(client, owner, repo, dryRun) + if err != nil { + fmt.Fprintf(os.Stderr, "Error getting github issues: %v\n", err) + os.Exit(1) + } + // Scan the provided path. + rev := reviver.New([]string{path}, []reviver.Bugger{bugger}) + if errs := rev.Run(); len(errs) > 0 { + fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs)) + for _, err := range errs { + fmt.Fprintf(os.Stderr, "\t%v\n", err) + } + os.Exit(1) + } + case "nogo": + // Did we get a commit? Try to extract one. + if len(commit) == 0 && !dryRun { + cmd := exec.Command("git", "rev-parse", "HEAD") + revBytes, err := cmd.Output() + if err != nil { + fmt.Fprintf(flag.CommandLine.Output(), "missing --commit option, unable to infer: %v\n", err) + flag.Usage() + os.Exit(1) + } + commit = strings.TrimSpace(string(revBytes)) + } + // Scan all findings. + poster := nogo.NewFindingsPoster(client, owner, repo, commit, dryRun) + if err := poster.Walk(path); err != nil { + fmt.Fprintln(os.Stderr, "Error finding nogo findings:", err) + os.Exit(1) + } + // Post to GitHub. + if err := poster.Post(); err != nil { + fmt.Fprintln(os.Stderr, "Error posting nogo findings:", err) + } + default: + // Not a known command. + fmt.Fprintf(flag.CommandLine.Output(), "unknown command: %s\n", command) + flag.Usage() + os.Exit(1) + } +} diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD new file mode 100644 index 000000000..0633eaf19 --- /dev/null +++ b/tools/github/nogo/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "nogo", + srcs = ["nogo.go"], + nogo = False, + visibility = [ + "//tools/github:__subpackages__", + ], + deps = [ + "//tools/nogo/util", + "@com_github_google_go_github_v28//github:go_default_library", + ], +) diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go new file mode 100644 index 000000000..b70dfe63b --- /dev/null +++ b/tools/github/nogo/nogo.go @@ -0,0 +1,126 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nogo provides nogo-related utilities. +package nogo + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/go-github/github" + "gvisor.dev/gvisor/tools/nogo/util" +) + +// FindingsPoster is a simple wrapper around the GitHub api. +type FindingsPoster struct { + owner string + repo string + commit string + dryRun bool + startTime time.Time + + findings map[util.Finding]struct{} + client *github.Client +} + +// NewFindingsPoster returns a object that can post findings. +func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun bool) *FindingsPoster { + return &FindingsPoster{ + owner: owner, + repo: repo, + commit: commit, + dryRun: dryRun, + startTime: time.Now(), + findings: make(map[util.Finding]struct{}), + client: client, + } +} + +// Walk walks the given path tree for findings files. +func (p *FindingsPoster) Walk(path string) error { + return filepath.Walk(path, func(filename string, info os.FileInfo, err error) error { + if err != nil { + return err + } + // Skip any directories or files not ending in .findings. + if !strings.HasSuffix(filename, ".findings") || info.IsDir() { + return nil + } + findings, err := util.ExtractFindingsFromFile(filename) + if err != nil { + return err + } + // Add all findings to the list. We use a map to ensure + // that each finding is unique. + for _, finding := range findings { + p.findings[finding] = struct{}{} + } + return nil + }) +} + +// Post posts all results to the GitHub API as a check run. +func (p *FindingsPoster) Post() error { + // Just show results? + if p.dryRun { + for finding, _ := range p.findings { + // Pretty print, so that this is useful for debugging. + fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Path, finding.Line, finding.Message) + } + return nil + } + + // Construct the message. + title := "nogo" + count := len(p.findings) + status := "completed" + conclusion := "success" + if count > 0 { + conclusion = "failure" // Contains errors. + } + summary := fmt.Sprintf("%d findings.", count) + opts := github.CreateCheckRunOptions{ + Name: title, + HeadSHA: p.commit, + Status: &status, + Conclusion: &conclusion, + StartedAt: &github.Timestamp{p.startTime}, + CompletedAt: &github.Timestamp{time.Now()}, + Output: &github.CheckRunOutput{ + Title: &title, + Summary: &summary, + AnnotationsCount: &count, + }, + } + annotationLevel := "failure" // Always. + for finding, _ := range p.findings { + opts.Output.Annotations = append(opts.Output.Annotations, &github.CheckRunAnnotation{ + Path: &finding.Path, + StartLine: &finding.Line, + EndLine: &finding.Line, + Message: &finding.Message, + Title: &finding.Category, + AnnotationLevel: &annotationLevel, + }) + } + + // Post to GitHub. + _, _, err := p.client.Checks.CreateCheckRun(context.Background(), p.owner, p.repo, opts) + return err +} diff --git a/tools/github/reviver/BUILD b/tools/github/reviver/BUILD new file mode 100644 index 000000000..7d78480a7 --- /dev/null +++ b/tools/github/reviver/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "reviver", + srcs = [ + "github.go", + "reviver.go", + ], + nogo = False, + visibility = [ + "//tools/github:__subpackages__", + ], + deps = ["@com_github_google_go_github_v28//github:go_default_library"], +) + +go_test( + name = "reviver_test", + size = "small", + srcs = [ + "github_test.go", + "reviver_test.go", + ], + library = ":reviver", + nogo = False, +) diff --git a/tools/issue_reviver/github/github.go b/tools/github/reviver/github.go index 8ffd7e606..a95df0fb6 100644 --- a/tools/issue_reviver/github/github.go +++ b/tools/github/reviver/github.go @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package github implements reviver.Bugger interface on top of Github issues. -package github +package reviver import ( "context" @@ -23,12 +22,10 @@ import ( "time" "github.com/google/go-github/github" - "golang.org/x/oauth2" - "gvisor.dev/gvisor/tools/issue_reviver/reviver" ) -// Bugger implements reviver.Bugger interface for github issues. -type Bugger struct { +// GitHubBugger implements Bugger interface for github issues. +type GitHubBugger struct { owner string repo string dryRun bool @@ -37,36 +34,25 @@ type Bugger struct { issues map[int]*github.Issue } -// NewBugger creates a new Bugger. -func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) { - b := &Bugger{ +// NewGitHubBugger creates a new GitHubBugger. +func NewGitHubBugger(client *github.Client, owner, repo string, dryRun bool) (*GitHubBugger, error) { + b := &GitHubBugger{ owner: owner, repo: repo, dryRun: dryRun, issues: map[int]*github.Issue{}, + client: client, } - if err := b.load(token); err != nil { + if err := b.load(); err != nil { return nil, err } return b, nil } -func (b *Bugger) load(token string) error { - ctx := context.Background() - if len(token) == 0 { - fmt.Print("No OAUTH token provided, using unauthenticated account.\n") - b.client = github.NewClient(nil) - } else { - ts := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: token}, - ) - tc := oauth2.NewClient(ctx, ts) - b.client = github.NewClient(tc) - } - +func (b *GitHubBugger) load() error { err := processAllPages(func(listOpts github.ListOptions) (*github.Response, error) { opts := &github.IssueListByRepoOptions{State: "open", ListOptions: listOpts} - tmps, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, opts) + tmps, resp, err := b.client.Issues.ListByRepo(context.Background(), b.owner, b.repo, opts) if err != nil { return resp, err } @@ -83,8 +69,8 @@ func (b *Bugger) load(token string) error { return nil } -// Activate implements reviver.Bugger. -func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) { +// Activate implements Bugger.Activate. +func (b *GitHubBugger) Activate(todo *Todo) (bool, error) { id, err := parseIssueNo(todo.Issue) if err != nil { return true, err diff --git a/tools/issue_reviver/github/github_test.go b/tools/github/reviver/github_test.go index a78b230ef..5df7e3624 100644 --- a/tools/issue_reviver/github/github_test.go +++ b/tools/github/reviver/github_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package github +package reviver import ( "testing" diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/github/reviver/reviver.go index 2af7f0d59..2af7f0d59 100644 --- a/tools/issue_reviver/reviver/reviver.go +++ b/tools/github/reviver/reviver.go diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/github/reviver/reviver_test.go index a9fb1f9f1..a9fb1f9f1 100644 --- a/tools/issue_reviver/reviver/reviver_test.go +++ b/tools/github/reviver/reviver_test.go diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 75e5c7888..d8045c295 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -113,3 +113,18 @@ The following are some guidelines for modifying the `go_marshal` tool: - No runtime reflection in the code generated for the marshallable interface. The entire point of the tool is to avoid runtime reflection. The generated tests may use reflection. + +## Debugging + +To enable debugging output from the go-marshal tool, use one of the following +options, depending on how go-marshal is being invoked: + +- Pass `--define gomarshal=verbose` to the bazel command. Note that this can + generate a lot of output depending on what's being compiled, as this will + enable debugging for all packages built by the command. + +- Set `marshal_debug = True` on the top-level `go_library` BUILD rule. + +- Set `debug = True` on the `go_marshal` BUILD rule. + +- Pass `-debug` to the go-marshal tool invocation. diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index ba98f3599..f44f83eab 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -4,11 +4,13 @@ def _go_marshal_impl(ctx): """Execute the go_marshal tool.""" output = ctx.outputs.lib output_test = ctx.outputs.test + output_test_unconditional = ctx.outputs.test_unconditional # Run the marshal command. args = ["-output=%s" % output.path] - args += ["-pkg=%s" % ctx.attr.package] - args += ["-output_test=%s" % output_test.path] + args.append("-pkg=%s" % ctx.attr.package) + args.append("-output_test=%s" % output_test.path) + args.append("-output_test_unconditional=%s" % output_test_unconditional.path) if ctx.attr.debug: args += ["-debug"] @@ -18,7 +20,7 @@ def _go_marshal_impl(ctx): args += [f.path for f in src.files.to_list()] ctx.actions.run( inputs = ctx.files.srcs, - outputs = [output, output_test], + outputs = [output, output_test, output_test_unconditional], mnemonic = "GoMarshal", progress_message = "go_marshal: %s" % ctx.label, arguments = args, @@ -48,6 +50,7 @@ go_marshal = rule( outputs = { "lib": "%{name}_unsafe.go", "test": "%{name}_test.go", + "test_unconditional": "%{name}_unconditional_test.go", }, ) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 72ed6d109..56fbcb5d2 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -68,6 +68,8 @@ type Generator struct { output *os.File // Output file to write generated tests. outputTest *os.File + // Output file to write unconditionally generated tests. + outputTestUC *os.File // Package name for the generated file. pkg string // Set of extra packages to import in the generated file. @@ -75,7 +77,7 @@ type Generator struct { } // NewGenerator creates a new code Generator. -func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) { +func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) { f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) @@ -84,12 +86,17 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G if err != nil { return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) } + fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open unconditional test output file %q: %v", out, err) + } g := Generator{ - inputs: srcs, - output: f, - outputTest: fTest, - pkg: pkg, - imports: newImportTable(), + inputs: srcs, + output: f, + outputTest: fTest, + outputTestUC: fTestUC, + pkg: pkg, + imports: newImportTable(), } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as @@ -454,6 +461,46 @@ func (g *Generator) Run() error { // source file. func (g *Generator) writeTests(ts []*testGenerator) error { var b sourceBuffer + + // Write the unconditional test file. This file is always compiled, + // regardless of what build tags were specified on the original input + // files. We use this file to guarantee we never end up with an empty test + // file, as that causes the build to fail with "no tests/benchmarks/examples + // found". + // + // There's no easy way to determine ahead of time if we'll end up with an + // empty build file since build constraints can arbitrarily cause some of + // the original types to be not defined. We also have no way to tell bazel + // to omit the entire test suite since the output files are already defined + // before go-marshal is called. + b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") + b.emit("package %s\n\n", g.pkg) + b.emit("func Example() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty, and ensures this package contains at\n") + b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n") + b.emit("// input package is marked marshallable, but emitting no testable entities \n") + b.emit("// results in a build failure.\n") + }) + b.emit("}\n") + if err := b.write(g.outputTestUC); err != nil { + return err + } + + // Now generate the real test file that contains the real types we + // processed. These need to be conditionally compiled according to the build + // tags, as the original types may not be defined under all build + // configurations. + + b.reset() + b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") + + // Emit build tags. + if t := tags.Aggregate(g.inputs); len(t) > 0 { + b.emit(strings.Join(t.Lines(), "\n")) + b.emit("\n\n") + } + b.emit("package %s\n\n", g.pkg) if err := b.write(g.outputTest); err != nil { return err @@ -470,26 +517,6 @@ func (g *Generator) writeTests(ts []*testGenerator) error { } // Write test functions. - - // If we didn't generate any Marshallable implementations, we can't just - // emit an empty test file, since that causes the build to fail with "no - // tests/benchmarks/examples found". Unfortunately we can't signal bazel to - // omit the entire package since the outputs are already defined before - // go-marshal is called. If we'd otherwise emit an empty test suite, emit an - // empty example instead. - if len(ts) == 0 { - b.reset() - b.emit("func Example() {\n") - b.inIndent(func() { - b.emit("// This example is intentionally empty to ensure this file contains at least\n") - b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") - b.emit("// is marked marshallable, but emitting a test file with no entities results\n") - b.emit("// in a build failure.\n") - }) - b.emit("}\n") - return b.write(g.outputTest) - } - for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index cf76b5241..36447b86b 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -43,8 +43,8 @@ type interfaceGenerator struct { // of t's interfaces. ms map[string]struct{} - // as records embedded fields in t that are potentially not packed. The key - // is the accessor for the field. + // as records fields in t that are potentially not packed. The key is the + // accessor for the field. as map[string]struct{} } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index 456662fab..fe76d3785 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -51,12 +51,6 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { // later. func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { forEachStructField(st, func(f *ast.Field) { - if len(f.Names) == 0 { - g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") - } - }) - - forEachStructField(st, func(f *ast.Field) { fieldDispatcher{ primitive: func(_, t *ast.Ident) { g.validatePrimitiveNewtype(t) @@ -101,7 +95,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { var dynamicSizeTerms []string forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { + primitive: func(_, t *ast.Ident) { if size, dynamic := g.scalarSize(t); !dynamic { primitiveSize += size } else { @@ -109,13 +103,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) } }, - selector: func(n, tX, tSel *ast.Ident) { + selector: func(_, tX, tSel *ast.Ident) { tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) g.recordUsedImport(tX.Name) g.recordUsedMarshallable(tName) dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) }, - array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) { lenExpr := g.arrayLenExpr(a) if size, dynamic := g.scalarSize(t); !dynamic { dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index d94314302..6a42691cd 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -79,7 +79,7 @@ type fieldDispatcher struct { } // Precondition: All dispatch callbacks that will be invoked must be -// provided. Embedded fields are not allowed, len(f.Names) >= 1. +// provided. func (fd fieldDispatcher) dispatch(f *ast.Field) { // Each field declaration may actually be multiple declarations of the same // type. For example, consider: @@ -88,12 +88,24 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) { // x, y, z int // } // - // We invoke the call-backs once per such instance. Embedded fields are not - // allowed, and results in a panic. + // We invoke the call-backs once per such instance. + + // Handle embedded fields. Embedded fields have no names, but can be + // referenced by the type name. if len(f.Names) < 1 { - panic("Precondition not met: attempted to dispatch on embedded field") + switch v := f.Type.(type) { + case *ast.Ident: + fd.primitive(v, v) + case *ast.SelectorExpr: + fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel) + default: + // Note: Arrays can't be embedded, which is handled here. + panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type)) + } + return } + // Non-embedded field. for _, name := range f.Names { switch v := f.Type.(type) { case *ast.Ident: diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go index f74be5c29..6e4a3e8c4 100644 --- a/tools/go_marshal/main.go +++ b/tools/go_marshal/main.go @@ -31,10 +31,11 @@ import ( ) var ( - pkg = flag.String("pkg", "", "output package") - output = flag.String("output", "", "output file") - outputTest = flag.String("output_test", "", "output file for tests") - imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") + pkg = flag.String("pkg", "", "output package") + output = flag.String("output", "", "output file") + outputTest = flag.String("output_test", "", "output file for tests") + outputTestUnconditional = flag.String("output_test_unconditional", "", "output file for unconditional tests") + imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") ) func main() { @@ -61,7 +62,7 @@ func main() { // as an import. extraImports = strings.Split(*imports, ",") } - g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports) + g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *outputTestUnconditional, *pkg, extraImports) if err != nil { panic(err) } diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index f75ca1b7f..d9e9f341b 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -174,3 +174,27 @@ type Type9 struct { x int64 y [sizeA]int32 } + +// Type10Embed is a test data type which is be embedded into another type. +// +// +marshal +type Type10Embed struct { + x int64 +} + +// Type10 is a test data type which contains an embedded struct. +// +// +marshal +type Type10 struct { + Type10Embed + y int64 +} + +// Type11 is a test data type which contains an embedded struct from an external +// package. +// +// +marshal +type Type11 struct { + ex.External + y int64 +} diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD deleted file mode 100644 index 35b0111ca..000000000 --- a/tools/issue_reviver/BUILD +++ /dev/null @@ -1,13 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "issue_reviver", - srcs = ["main.go"], - nogo = False, - deps = [ - "//tools/issue_reviver/github", - "//tools/issue_reviver/reviver", - ], -) diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD deleted file mode 100644 index 555abd296..000000000 --- a/tools/issue_reviver/github/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "github", - srcs = ["github.go"], - nogo = False, - visibility = [ - "//tools/issue_reviver:__subpackages__", - ], - deps = [ - "//tools/issue_reviver/reviver", - "@com_github_google_go_github_v28//github:go_default_library", - "@org_golang_x_oauth2//:go_default_library", - ], -) - -go_test( - name = "github_test", - size = "small", - srcs = ["github_test.go"], - library = ":github", - nogo = False, -) diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go deleted file mode 100644 index 47c796b8a..000000000 --- a/tools/issue_reviver/main.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package main is the entry point for issue_reviver. -package main - -import ( - "flag" - "fmt" - "io/ioutil" - "os" - "strings" - - "gvisor.dev/gvisor/tools/issue_reviver/github" - "gvisor.dev/gvisor/tools/issue_reviver/reviver" -) - -var ( - owner string - repo string - tokenFile string - path string - dryRun bool -) - -// Keep the options simple for now. Supports only a single path and repo. -func init() { - flag.StringVar(&owner, "owner", "", "Github project org/owner to look for issues") - flag.StringVar(&repo, "repo", "", "Github repo to look for issues") - flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github") - flag.StringVar(&path, "path", ".", "Path to scan for TODOs") - flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues") -} - -func main() { - // Set defaults from the environment. - repository := os.Getenv("GITHUB_REPOSITORY") - if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 { - owner = parts[0] - repo = parts[1] - } - - // Parse flags. - flag.Parse() - - // Check for mandatory parameters. - if len(owner) == 0 { - fmt.Println("missing --owner option.") - flag.Usage() - os.Exit(1) - } - if len(repo) == 0 { - fmt.Println("missing --repo option.") - flag.Usage() - os.Exit(1) - } - if len(path) == 0 { - fmt.Println("missing --path option.") - flag.Usage() - os.Exit(1) - } - - // The access token may be passed as a file so it doesn't show up in - // command line arguments. It also may be provided through the - // environment to faciliate use through GitHub's CI system. - token := os.Getenv("GITHUB_TOKEN") - if len(tokenFile) != 0 { - bytes, err := ioutil.ReadFile(tokenFile) - if err != nil { - fmt.Println(err.Error()) - os.Exit(1) - } - token = string(bytes) - } - - bugger, err := github.NewBugger(token, owner, repo, dryRun) - if err != nil { - fmt.Fprintln(os.Stderr, "Error getting github issues:", err) - os.Exit(1) - } - rev := reviver.New([]string{path}, []reviver.Bugger{bugger}) - if errs := rev.Run(); len(errs) > 0 { - fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs)) - for _, err := range errs { - fmt.Fprintf(os.Stderr, "\t%v\n", err) - } - os.Exit(1) - } -} diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD deleted file mode 100644 index d262932bd..000000000 --- a/tools/issue_reviver/reviver/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") - -package(licenses = ["notice"]) - -go_library( - name = "reviver", - srcs = ["reviver.go"], - visibility = [ - "//tools/issue_reviver:__subpackages__", - ], -) - -go_test( - name = "reviver_test", - size = "small", - srcs = ["reviver_test.go"], - library = ":reviver", -) diff --git a/tools/nogo/build.go b/tools/nogo/build.go index 37947b5c3..39c2ae418 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -26,6 +26,9 @@ var ( // and should not have any special prefix applied. internalPrefix = fmt.Sprintf("^") + // internalDefault is applied when no paths are provided. + internalDefault = fmt.Sprintf("%s/.*", notPath("external")) + // externalPrefix is external workspace packages. externalPrefix = "^external/" ) diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go index 57a250501..5c39be630 100644 --- a/tools/nogo/matchers.go +++ b/tools/nogo/matchers.go @@ -16,7 +16,6 @@ package nogo import ( "go/token" - "path/filepath" "regexp" "strings" @@ -44,11 +43,30 @@ type pathRegexps struct { func buildRegexps(prefix string, args ...string) []*regexp.Regexp { result := make([]*regexp.Regexp, 0, len(args)) for _, arg := range args { - result = append(result, regexp.MustCompile(filepath.Join(prefix, arg))) + result = append(result, regexp.MustCompile(prefix+arg)) } return result } +// notPath works around the lack of backtracking. +// +// It is used to construct a regular expression for non-matching components. +func notPath(name string) string { + sb := strings.Builder{} + sb.WriteString("(") + for i := range name { + if i > 0 { + sb.WriteString("|") + } + sb.WriteString(name[:i]) + sb.WriteString("[^") + sb.WriteByte(name[i]) + sb.WriteString("/][^/]*") + } + sb.WriteString(")") + return sb.String() +} + // ShouldReport implements matcher.ShouldReport. func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { fullPos := fs.Position(d.Pos).String() @@ -79,7 +97,7 @@ func externalExcluded(paths ...string) *pathRegexps { // internalMatches returns a path matcher for internal packages. func internalMatches() *pathRegexps { return &pathRegexps{ - expr: buildRegexps(internalPrefix, ".*"), + expr: buildRegexps(internalPrefix, internalDefault), include: true, } } diff --git a/tools/nogo/util/BUILD b/tools/nogo/util/BUILD new file mode 100644 index 000000000..7ab340b51 --- /dev/null +++ b/tools/nogo/util/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "util", + srcs = ["util.go"], + visibility = ["//visibility:public"], +) diff --git a/tools/nogo/util/util.go b/tools/nogo/util/util.go new file mode 100644 index 000000000..919fec799 --- /dev/null +++ b/tools/nogo/util/util.go @@ -0,0 +1,85 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package util contains nogo-related utilities. +package util + +import ( + "fmt" + "io/ioutil" + "regexp" + "strconv" + "strings" +) + +// findingRegexp is used to parse findings. +var findingRegexp = regexp.MustCompile(`([a-zA-Z0-9_\/\.-]+): (-|([a-zA-Z0-9_\/\.-]+):([0-9]+)(:([0-9]+))?): (.*)`) + +const ( + categoryIndex = 1 + fullPathAndLineIndex = 2 + fullPathIndex = 3 + lineIndex = 4 + messageIndex = 7 +) + +// Finding is a single finding. +type Finding struct { + Category string + Path string + Line int + Message string +} + +// ExtractFindingsFromFile loads findings from a file. +func ExtractFindingsFromFile(filename string) ([]Finding, error) { + content, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + return ExtractFindingsFromBytes(content) +} + +// ExtractFindingsFromBytes loads findings from bytes. +func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) { + lines := strings.Split(string(content), "\n") + for _, singleLine := range lines { + // Skip blank lines. + singleLine = strings.TrimSpace(singleLine) + if singleLine == "" { + continue + } + m := findingRegexp.FindStringSubmatch(singleLine) + if m == nil { + // We shouldn't see findings like this. + return findings, fmt.Errorf("poorly formated line: %v", singleLine) + } + if m[fullPathAndLineIndex] == "-" { + continue // No source file available. + } + // Cleanup the message. + message := m[messageIndex] + message = strings.Replace(message, " → ", "\n → ", -1) + message = strings.Replace(message, " or ", "\n or ", -1) + // Construct a new annotation. + lineNumber, _ := strconv.ParseUint(m[lineIndex], 10, 32) + findings = append(findings, Finding{ + Category: m[categoryIndex], + Path: m[fullPathIndex], + Line: int(lineNumber), + Message: message, + }) + } + return findings, nil +} |